Open In Colab

LICENSING NOTICE¶

Note that all users who use Vital DB, an open biosignal dataset, must agree to the Data Use Agreement below. If you do not agree, please close this window. The Data Use Agreement is available here: https://vitaldb.net/dataset/#h.vcpgs1yemdb5

This is the development version of the project code¶

For the Project Draft submission see the DL4H_Team_24_Project_Draft.ipynb notebook in the project repository.

Project repository¶

The project repository can be found at: https://github.com/abarrie2/cs598-dlh-project

Introduction¶

This project aims to reproduce findings from the paper titled "Predicting intraoperative hypotension using deep learning with waveforms of arterial blood pressure, electroencephalogram, and electrocardiogram: Retrospective study" by Jo Y-Y et al. (2022) [1]. This study introduces a deep learning model that predicts intraoperative hypotension (IOH) events before they occur, utilizing a combination of arterial blood pressure (ABP), electroencephalogram (EEG), and electrocardiogram (ECG) signals.

Background of the Problem¶

Intraoperative hypotension (IOH) is a common and significant surgical complication defined by a mean arterial pressure drop below 65 mmHg. It is associated with increased risks of myocardial infarction, acute kidney injury, and heightened postoperative mortality. Effective prediction and timely intervention can substantially enhance patient outcomes.

Evolution of IOH Prediction¶

Initial attempts to predict IOH primarily used arterial blood pressure (ABP) waveforms. A foundational study by Hatib F et al. (2018) titled "Machine-learning Algorithm to Predict Hypotension Based on High-fidelity Arterial Pressure Waveform Analysis" [2] showed that machine learning could forecast IOH events using ABP with reasonable accuracy. This finding spurred further research into utilizing various physiological signals for IOH prediction.

Subsequent advancements included the development of the Acumen™ hypotension prediction index, which was studied in "AcumenTM hypotension prediction index guidance for prevention and treatment of hypotension in noncardiac surgery: a prospective, single-arm, multicenter trial" by Bao X et al. (2024) [3]. This trial integrated a hypotension prediction index into blood pressure monitoring equipment, demonstrating its effectiveness in reducing the number and duration of IOH events during surgeries. Further study is needed to determine whether this resultant reduction in IOH events transalates into improved postoperative patient outcomes.

Current Study¶

Building on these advancements, the paper by Jo Y-Y et al. (2022) proposes a deep learning approach that enhances prediction accuracy by incorporating EEG and ECG signals along with ABP. This multi-modal method, evaluated over prediction windows of 3, 5, 10, and 15 minutes, aims to provide a comprehensive physiological profile that could predict IOH more accurately and earlier. Their results indicate that the combination of ABP and EEG significantly improves performance metrics such as AUROC and AUPRC, outperforming models that use fewer signals or different combinations.

Our project seeks to reproduce and verify Jo Y-Y et al.'s results to assess whether this integrated approach can indeed improve IOH prediction accuracy, thereby potentially enhancing surgical safety and patient outcomes.

Scope of Reproducibility:¶

The original paper investigated the following hypotheses:

  1. Hypothesis 1: A model using ABP and ECG will outperform a model using ABP alone in predicting IOH.
  2. Hypothesis 2: A model using ABP and EEG will outperform a model using ABP alone in predicting IOH.
  3. Hypothesis 3: A model using ABP, EEG, and ECG will outperform a model using ABP alone in predicting IOH.

Results were compared using AUROC and AUPRC scores. Based on the results described in the original paper, we expect that Hypothesis 2 will be confirmed, and that Hypotheses 1 and 3 will not be confirmed.

In order to perform the corresponding experiments, we will implement a CNN-based model that can be configured to train and infer using the following four model variations:

  1. ABP data alone
  2. ABP and ECG data
  3. ABP and EEG data
  4. ABP, ECG, and EEG data

We will measure the performance of these configurations using the same AUROC and AUPRC metrics as used in the original paper. To test hypothesis 1 we will compare the AUROC and AUPRC measures between model variation 1 and model variation 2. To test hypothesis 2 we will compare the AUROC and AUPRC measures between model variation 1 and model variation 3. To test hypothesis 3 we will compare the AUROC and AUPRC measures between model variation 1 and model variation 4. For all of the above measures and experiment combinations, we will operate multiple experiments where the time-to-IOH event prediction will use the following prediction windows:

  1. 3 minutes before event
  2. 5 minutes before event
  3. 10 minutes before event
  4. 15 minutes before event

In the event that we are compute-bound, we will prioritize the 3-minute prediction window experiments as they are the most relevant to the original paper's findings.

The predictive power of ABP, ECG and ABP + ECG models at 3-, 5-, 10- and 15-minute prediction windows: Predictive power of ABP, ECG and ABP + ECG models at 3-, 5-, 10- and 15-minute prediction windows

Modifications made for demo mode¶

In order to demonstrate the functioning of the code in a short (ie, <8 minute limit) the following options and modifications were used:

  1. MAX_CASES was set to 20. The total number of cases to be used in the full training set is 3296, but the smaller numbers allows demonstration of each section of the pipeline.
  2. vitaldb_cache is prepopulated in Google Colab. The cache file is approx. 800MB and contains the raw and mini-fied copies of the source dataset and is downloaded from Google Drive. This is much faster than using the vitaldb API, but is again only a fraction of the data. The full dataset can be downloaded with the API or prepopulated by following the instructions in the "Bulk Data Download" section below.
  3. max_epochs is set to 6. With the small dataset, training is fast and shows the decreasing training and validation losses. In the full model run, max_epochs will be set to 100. In both cases early stopping is enabled and will stop training if the validation losses stop decreasing for five consecutive epochs.
  4. Only the "ABP + EEG" combination will be run. In the final report, additional combinations will be run, as discussed later.
  5. Only the 3-minute prediction window will be run. In the final report, additional prediction windows (5, 10 and 15 minutes) will be run, as discussed later.
  6. No ablations are run in the demo. These will be completed for the final report.

Methodology¶

Methodology from Final Rubrik¶

  • Environment
    • Python version
    • Dependencies/packages needed
  • Data
    • Data download instruction
    • Data descriptions with helpful charts and visualizations
    • Preprocessing code + command
  • Model
    • Citation to the original paper
    • Link to the original paper’s repo (if applicable)
    • Model descriptions
    • Implementation code
    • Pretrained model (if applicable)
  • Training
    • Hyperparams
      • Report at least 3 types of hyperparameters such as learning rate, batch size, hidden size, dropout
    • Computational requirements
      • Report at least 3 types of requirements such as type of hardware, average runtime for each epoch, total number of trials, GPU hrs used, # training epochs
      • Training code
  • Evaluation
    • Metrics descriptions
    • Evaluation code

The methodology section is composed of the following subsections: Environment, Data and Model.

  • Environment: This section describes the setup of the environment, including the installation of necessary libraries and the configuration of the runtime environment.
  • Data: This section describes the dataset used in the study, including its collection and preprocessing.
    • Data Collection: This section describes the process of downloading the dataset from VitalDB and populating the local data cache.
    • Data Preprocessing: This section describes the preprocessing steps applied to the dataset, including data selection, data cleaning, and feature extraction.
  • Model: This section describes the deep learning model used in the study, including its implementation, training, and evaluation.
    • Model Implementation: This section describes the implementation of the deep learning model, including the architecture, loss function, and optimization algorithm.
    • Model Training: This section describes the training process, including the training loop, hyperparameters, and training strategy.
    • Model Evaluation: This section describes the evaluation process, including the metrics used, the evaluation strategy, and the results obtained.

Environment¶

Create environment¶

The environment setup differs based on whether you are running the code on a local machine or on Google Colab. The following sections provide instructions for setting up the environment in each case.

Local machine¶

Create conda environment for the project using the environment.yml file:

conda env create --prefix .envs/dlh-team24 -f environment.yml

Activate the environment with:

conda activate .envs/dlh-team24

Google Colab¶

The following code snippet installs the required packages and downloads the necessary files in a Google Colab environment:

In [1]:
# Google Colab environments have a `/content` directory. Use this as a proxy for running Colab-only code
COLAB_ENV = "google.colab" in str(get_ipython())
if COLAB_ENV:
    #install vitaldb
    %pip install vitaldb

    # Executing in Colab therefore download cached preprocessed data.
    # TODO: Integrate this with the setup local cache data section below.
    # Check for file existence before overwriting.
    import gdown
    gdown.download(id="15b5Nfhgj3McSO2GmkVUKkhSSxQXX14hJ", output="vitaldb_cache.tgz")
    !tar -zxf vitaldb_cache.tgz

    # Download sqi_filter.csv from github repo
    !wget https://raw.githubusercontent.com/abarrie2/cs598-dlh-project/main/sqi_filter.csv

All other required packages are already installed in the Google Colab environment.

Load environment¶

In [2]:
# Import packages
import os
import random
import copy
from collections import defaultdict

from timeit import default_timer as timer

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, roc_auc_score, precision_recall_curve, auc, confusion_matrix
from sklearn.metrics import RocCurveDisplay, PrecisionRecallDisplay, average_precision_score
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
import torch
from torch.utils.data import Dataset
import vitaldb
import h5py

import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from datetime import datetime

Set random seeds to generate consistent results:

In [3]:
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(RANDOM_SEED)
    torch.cuda.manual_seed_all(RANDOM_SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
os.environ["PYTHONHASHSEED"] = str(RANDOM_SEED)

Data¶

Data Description¶

Source¶

Data for this project is sourced from the open biosignal VitalDB dataset as described in "VitalDB, a high-fidelity multi-parameter vital signs database in surgical patients" by Lee H-C et al. (2022) [4], which contains perioperative vital signs and numerical data from 6,388 cases of non-cardiac (general, thoracic, urological, and gynecological) surgery patients who underwent routine or emergency surgery at Seoul National University Hospital between 2016 and 2017. The dataset includes ABP, ECG, and EEG signals, as well as other physiological data. The dataset is available through an API and Python library, and at PhysioNet: https://physionet.org/content/vitaldb/1.0.0/

Statistics¶

Characteristics of the dataset: | Characteristic | Value | Details | |-----------------------|-----------------------------|------------------------| | Total number of cases | 6,388 | | | Sex (male) | 3,243 (50.8%) | | | Age (years) | 59 | Range: 48-68 | | Height (cm) | 162 | Range: 156-169 | | Weight (kg) | 61 | Range: 53-69 | | Tram-Rac 4A tracks | 6,355 (99.5%) | Sampling rate: 500Hz | | BIS Vista tracks | 5,566 (87.1%) | Sampling rate: 128Hz | | Case duration (min) | 189 | Range: 27-1041 |

Labels are only known after processing the data. In the original paper, there were an average of 1.6 IOH events per case and 5.7 non-events per case so we expect approximately 10,221 IOH events and 364,116 non-events in the dataset.

Data Processing¶

Data will be processed as follows:

  1. Load the dataset from VitalDB, or from a local cache if previously downloaded.
  2. Apply the inclusion and exclusion selection criteria to filter the dataset according to surgery metadata.
  3. Generate a minified dataset by discarding all tracks except ABP, ECG, and EEG.
  4. Preprocess the data by applying band-pass and z-score normalization to the ECG and EEG signals, and filtering out ABP signals below a Signal Quality Index (SQI) threshold.
  5. Generate event and non-event samples by extracting 60-second segments around IOH events and non-events.
  6. Split the dataset into training, validation, and test sets with a 6:1:3 ratio, ensuring that samples from a single case are not split across different sets to avoid data leakage.

Set Up Local Data Caches¶

VitalDB data is static, so local copies can be stored and reused to avoid expensive downloads and to speed up data processing.

The default directory defined below is in the project .gitignore file. If this is modified, the new directory should also be added to the project .gitignore.

In [4]:
VITALDB_CACHE = './vitaldb_cache'
VITAL_ALL = f"{VITALDB_CACHE}/vital_all"
VITAL_MINI = f"{VITALDB_CACHE}/vital_mini"
VITAL_METADATA = f"{VITALDB_CACHE}/metadata"
VITAL_MODELS = f"{VITALDB_CACHE}/models"
VITAL_PREPROCESS_SCRATCH = f"{VITALDB_CACHE}/data_scratch"
VITAL_EXTRACTED_SEGMENTS = f"{VITALDB_CACHE}/segments"
In [5]:
TRACK_CACHE = None
SEGMENT_CACHE = None

# when USE_MEMORY_CACHING is enabled, track data will be persisted in an in-memory cache. Not useful once we have already pre-extracted all event segments
# DON'T USE: Stores items in memory that are later not used. Causes OOM on segment extraction.
USE_MEMORY_CACHING = False

# When RESET_CACHE is set to True, it will ensure the TRACK_CACHE is disposed and recreated when we do dataset initialization.
# Use as a shortcut to wiping cache rather than restarting kernel
RESET_CACHE = False

PREDICTION_WINDOW = 10
#PREDICTION_WINDOW = 'ALL'

ALL_PREDICTION_WINDOWS = [3, 5, 10, 15]

# Maximum number of cases of interest for which to download data.
# Set to a small value (ex: 20) for demo purposes, else set to None to disable and download and process all.
MAX_CASES = None
#MAX_CASES = 200

# Preloading Cases: when true, all matched cases will have the _mini tracks extracted and put into in-mem dict
PRELOADING_CASES = False
PRELOADING_SEGMENTS = True
# Perform Data Preprocessing: do we want to take the raw vital file and extract segments of interest for training?
PERFORM_DATA_PREPROCESSING = False
In [6]:
if not os.path.exists(VITALDB_CACHE):
  os.mkdir(VITALDB_CACHE)
if not os.path.exists(VITAL_ALL):
  os.mkdir(VITAL_ALL)
if not os.path.exists(VITAL_MINI):
  os.mkdir(VITAL_MINI)
if not os.path.exists(VITAL_METADATA):
  os.mkdir(VITAL_METADATA)
if not os.path.exists(VITAL_MODELS):
  os.mkdir(VITAL_MODELS)
if not os.path.exists(VITAL_PREPROCESS_SCRATCH):
  os.mkdir(VITAL_PREPROCESS_SCRATCH)
if not os.path.exists(VITAL_EXTRACTED_SEGMENTS):
  os.mkdir(VITAL_EXTRACTED_SEGMENTS)

print(os.listdir(VITALDB_CACHE))
['segments_bak', '.DS_Store', 'vital_all', 'models', 'docs', 'vital_mini.tar', 'data_scratch', 'osfs', 'vital_mini', 'metadata', 'segments_bak_0428_00', 'segments', 'models_old']

Bulk Data Download¶

This step is not required, but will significantly speed up downstream processing and avoid a high volume of API requests to the VitalDB web site.

The cache population code checks if the .vital files are locally available, and can be populated by calling the vitaldb API or by manually prepopulating the cache (recommended)

  • Manually downloaded the dataset from the following site: https://physionet.org/content/vitaldb/1.0.0/
    • Download the zip file in a browser, or
    • Use wget -r -N -c -np https://physionet.org/files/vitaldb/1.0.0/ to download the files in a terminal
  • Move the contents of vital_files into the ${VITAL_ALL} directory.
In [7]:
# Returns the Pandas DataFrame for the specified dataset.
#   One of 'cases', 'labs', or 'trks'
# If the file exists locally, create and return the DataFrame.
# Else, download and cache the csv first, then return the DataFrame.
def vitaldb_dataframe_loader(dataset_name):
    if dataset_name not in ['cases', 'labs', 'trks']:
        raise ValueError(f'Invalid dataset name: {dataset_name}')
    file_path = f'{VITAL_METADATA}/{dataset_name}.csv'
    if os.path.isfile(file_path):
        print(f'{dataset_name}.csv exists locally.')
        df = pd.read_csv(file_path)
        return df
    else:
        print(f'downloading {dataset_name} and storing in the local cache for future reuse.')
        df = pd.read_csv(f'https://api.vitaldb.net/{dataset_name}')
        df.to_csv(file_path, index=False)
        return df

Exploratory Data Analysis¶

Cases¶

In [8]:
cases = vitaldb_dataframe_loader('cases')
cases = cases.set_index('caseid')
cases.shape
cases.csv exists locally.
Out[8]:
(6388, 73)
In [9]:
cases.index.nunique()
Out[9]:
6388
In [10]:
cases.head()
Out[10]:
subjectid casestart caseend anestart aneend opstart opend adm dis icu_days ... intraop_colloid intraop_ppf intraop_mdz intraop_ftn intraop_rocu intraop_vecu intraop_eph intraop_phe intraop_epi intraop_ca
caseid
1 5955 0 11542 -552 10848.0 1668 10368 -236220 627780 0 ... 0 120 0.0 100 70 0 10 0 0 0
2 2487 0 15741 -1039 14921.0 1721 14621 -221160 1506840 0 ... 0 150 0.0 0 100 0 20 0 0 0
3 2861 0 4394 -590 4210.0 1090 3010 -218640 40560 0 ... 0 0 0.0 0 50 0 0 0 0 0
4 1903 0 20990 -778 20222.0 2522 17822 -201120 576480 1 ... 0 80 0.0 100 100 0 50 0 0 0
5 4416 0 21531 -1009 22391.0 2591 20291 -67560 3734040 13 ... 0 0 0.0 0 160 0 10 900 0 2100

5 rows × 73 columns

In [11]:
cases['sex'].value_counts()
Out[11]:
sex
M    3243
F    3145
Name: count, dtype: int64

Tracks¶

In [12]:
trks = vitaldb_dataframe_loader('trks')
trks = trks.set_index('caseid')
trks.shape
trks.csv exists locally.
Out[12]:
(486449, 2)
In [13]:
trks.index.nunique()
Out[13]:
6388
In [14]:
trks.groupby('caseid')[['tid']].count().plot();
In [15]:
trks.groupby('caseid')[['tid']].count().hist();
In [16]:
trks.groupby('tname').count().sort_values(by='tid', ascending=False)
Out[16]:
tid
tname
Solar8000/HR 6387
Solar8000/PLETH_SPO2 6386
Solar8000/PLETH_HR 6386
Primus/CO2 6362
Primus/PAMB_MBAR 6361
... ...
Orchestra/AMD_VOL 1
Solar8000/ST_V5 1
Orchestra/NPS_VOL 1
Orchestra/AMD_RATE 1
Orchestra/VEC_VOL 1

196 rows × 1 columns

Parameters of Interest¶

Hemodynamic Parameters Reference¶

https://vitaldb.net/dataset/?query=overview#h.f7d712ycdpk2

SNUADC/ART

arterial blood pressure waveform

Parameter, Description, Type/Hz, Unit

SNUADC/ART, Arterial pressure wave, W/500, mmHg

In [17]:
trks[trks['tname'].str.contains('SNUADC/ART')].shape
Out[17]:
(3645, 2)

SNUADC/ECG_II

electrocardiogram waveform

Parameter, Description, Type/Hz, Unit

SNUADC/ECG_II, ECG lead II wave, W/500, mV

In [18]:
trks[trks['tname'].str.contains('SNUADC/ECG_II')].shape
Out[18]:
(6355, 2)

BIS/EEG1_WAV

electroencephalogram waveform

Parameter, Description, Type/Hz, Unit

BIS/EEG1_WAV, EEG wave from channel 1, W/128, uV

In [19]:
trks[trks['tname'].str.contains('BIS/EEG1_WAV')].shape
Out[19]:
(5871, 2)

Cases of Interest¶

These are the subset of case ids for which modelling and analysis will be performed based upon inclusion criteria and waveform data availability.

In [20]:
# TRACK NAMES is used for metadata analysis via API
TRACK_NAMES = ['SNUADC/ART', 'SNUADC/ECG_II', 'BIS/EEG1_WAV']
TRACK_SRATES = [500, 500, 128]
# EXTRACTION TRACK NAMES adds the EVENT track which is only used when doing actual file i/o
EXTRACTION_TRACK_NAMES = ['SNUADC/ART', 'SNUADC/ECG_II', 'BIS/EEG1_WAV', 'EVENT']
EXTRACTION_TRACK_SRATES = [500, 500, 128, 1]
In [21]:
# As in the paper, select cases which meet the following criteria:
#
# For patients, the inclusion criteria were as follows:
# (1) adults (age >= 18)
# (2) administered general anaesthesia
# (3) undergone non-cardiac surgery. 
#
# For waveform data, the inclusion criteria were as follows:
# (1) no missing monitoring for ABP, ECG, and EEG waveforms
# (2) no cases containing false events or non-events due to poor signal quality
#     (checked in second stage of data preprocessing)

# Adult
inclusion_1 = cases.loc[cases['age'] >= 18].index
print(f'{len(cases)-len(inclusion_1)} cases excluded, {len(inclusion_1)} remaining due to age criteria')

# General Anesthesia
inclusion_2 = cases.loc[cases['ane_type'] == 'General'].index
print(f'{len(cases)-len(inclusion_2)} cases excluded, {len(inclusion_2)} remaining due to anesthesia criteria')

# Non-cardiac surgery
inclusion_3 = cases.loc[
    ~cases['opname'].str.contains("cardiac", case=False)
    & ~cases['opname'].str.contains("aneurysmal", case=False)
].index
print(f'{len(cases)-len(inclusion_3)} cases excluded, {len(inclusion_3)} remaining due to non-cardiac surgery criteria')

# ABP, ECG, EEG waveforms
inclusion_4 = trks.loc[trks['tname'].isin(TRACK_NAMES)].index.value_counts()
inclusion_4 = inclusion_4[inclusion_4 == len(TRACK_NAMES)].index
print(f'{len(cases)-len(inclusion_4)} cases excluded, {len(inclusion_4)} remaining due to missing waveform data')

# SQI filter
# NOTE: this depends on a sqi_filter.csv generated by external processing
inclusion_5 = pd.read_csv('sqi_filter.csv', header=None, names=['caseid','sqi']).set_index('caseid').index
print(f'{len(cases)-len(inclusion_5)} cases excluded, {len(inclusion_5)} remaining due to SQI threshold not being met')

# Only include cases with known good waveforms.
exclusion_6 = pd.read_csv('malformed_tracks_filter.csv', header=None, names=['caseid']).set_index('caseid').index
inclusion_6 = cases.index.difference(exclusion_6)
print(f'{len(cases)-len(inclusion_6)} cases excluded, {len(inclusion_6)} remaining due to malformed waveforms')

cases_of_interest_idx = inclusion_1 \
    .intersection(inclusion_2) \
    .intersection(inclusion_3) \
    .intersection(inclusion_4) \
    .intersection(inclusion_5) \
    .intersection(inclusion_6)

cases_of_interest = cases.loc[cases_of_interest_idx]

print()
print(f'{cases_of_interest_idx.shape[0]} out of {cases.shape[0]} total cases remaining after exclusions applied')

# Trim cases of interest to MAX_CASES
if MAX_CASES:
    cases_of_interest_idx = cases_of_interest_idx[:MAX_CASES]
print(f'{cases_of_interest_idx.shape[0]} cases of interest selected')
57 cases excluded, 6331 remaining due to age criteria
345 cases excluded, 6043 remaining due to anesthesia criteria
14 cases excluded, 6374 remaining due to non-cardiac surgery criteria
3019 cases excluded, 3369 remaining due to missing waveform data
0 cases excluded, 6388 remaining due to SQI threshold not being met
186 cases excluded, 6202 remaining due to malformed waveforms

3110 out of 6388 total cases remaining after exclusions applied
3110 cases of interest selected
In [22]:
cases_of_interest.head(n=5)
Out[22]:
subjectid casestart caseend anestart aneend opstart opend adm dis icu_days ... intraop_colloid intraop_ppf intraop_mdz intraop_ftn intraop_rocu intraop_vecu intraop_eph intraop_phe intraop_epi intraop_ca
caseid
1 5955 0 11542 -552 10848.0 1668 10368 -236220 627780 0 ... 0 120 0.0 100 70 0 10 0 0 0
4 1903 0 20990 -778 20222.0 2522 17822 -201120 576480 1 ... 0 80 0.0 100 100 0 50 0 0 0
7 5124 0 15770 477 14817.0 3177 14577 -154320 623280 3 ... 0 0 0.0 0 120 0 0 0 0 0
10 2175 0 20992 -1743 21057.0 2457 19857 -220740 3580860 1 ... 0 90 0.0 0 110 0 20 500 0 600
12 491 0 31203 -220 31460.0 5360 30860 -208500 1519500 4 ... 200 100 0.0 100 70 0 20 0 0 3300

5 rows × 73 columns

Tracks of Interest¶

These are the subset of tracks (waveforms) for the cases of interest identified above.

In [23]:
# A single case maps to one or more waveform tracks. Select only the tracks required for analysis.
trks_of_interest = trks.loc[cases_of_interest_idx][trks.loc[cases_of_interest_idx]['tname'].isin(TRACK_NAMES)]
trks_of_interest.shape
Out[23]:
(9330, 2)
In [24]:
trks_of_interest.head(n=5)
Out[24]:
tname tid
caseid
1 BIS/EEG1_WAV 0aa685df768489a18a5e9f53af0d83bf60890c73
1 SNUADC/ART 724cdd7184d7886b8f7de091c5b135bd01949959
1 SNUADC/ECG_II 8c9161aaae8cb578e2aa7b60f44234d98d2b3344
4 BIS/EEG1_WAV 1b4c2379be3397a79d3787dd810190150dc53f27
4 SNUADC/ART e28777c4706fe3a5e714bf2d91821d22d782d802
In [25]:
trks_of_interest_idx = trks_of_interest.set_index('tid').index
trks_of_interest_idx.shape
Out[25]:
(9330,)

Build Tracks Cache for Local Processing¶

Tracks data are large and therefore expensive to download every time used. By default, the .vital file format stores all tracks for each case internally. Since only select tracks per case are required, each .vital file can be further reduced by discarding the unused tracks.

In [26]:
# Ensure the full vital file dataset is available for cases of interest.
count_downloaded = 0
count_present = 0

#for i, idx in enumerate(cases.index):
for idx in cases_of_interest_idx:
    full_path = f'{VITAL_ALL}/{idx:04d}.vital'
    if not os.path.isfile(full_path):
        print(f'Missing vital file: {full_path}')
        # Download and save the file.
        vf = vitaldb.VitalFile(idx)
        vf.to_vital(full_path)
        count_downloaded += 1
    else:
        count_present += 1

print()
print(f'Count of cases of interest:           {cases_of_interest_idx.shape[0]}')
print(f'Count of vital files downloaded:      {count_downloaded}')
print(f'Count of vital files already present: {count_present}')
Count of cases of interest:           3110
Count of vital files downloaded:      0
Count of vital files already present: 3110
In [27]:
# Convert vital files to "mini" versions including only the subset of tracks defined in TRACK_NAMES above.
# Only perform conversion for the cases of interest.
# NOTE: If this cell is interrupted, it can be restarted and will continue where it left off.
count_minified = 0
count_present = 0
count_missing_tracks = 0
count_not_fixable = 0

vf = vitaldb.VitalFile('./vitaldb_cache/vital_all/0001.vital', EXTRACTION_TRACK_NAMES)
print(vf)

# If set to true, local mini files are checked for all tracks even if already present.
FORCE_VALIDATE = False

for idx in cases_of_interest_idx:
    full_path = f'{VITAL_ALL}/{idx:04d}.vital'
    mini_path = f'{VITAL_MINI}/{idx:04d}_mini.vital'

    if FORCE_VALIDATE or not os.path.isfile(mini_path):
        print(f'Creating mini vital file: {idx}')
        vf = vitaldb.VitalFile(full_path, EXTRACTION_TRACK_NAMES)
        
        if len(vf.get_track_names()) != 4:
            print(f'Missing track in vital file: {idx}, {set(EXTRACTION_TRACK_NAMES).difference(set(vf.get_track_names()))}')
            count_missing_tracks += 1
            
            # Attempt to download from VitalDB directly and see if missing tracks are present.
            vf = vitaldb.VitalFile(idx, EXTRACTION_TRACK_NAMES)
            
            if len(vf.get_track_names()) != 3:
                print(f'Unable to fix missing tracks: {idx}')
                count_not_fixable += 1
                continue
                
            if vf.get_track_samples(EXTRACTION_TRACK_NAMES[0], 1/EXTRACTION_TRACK_SRATES[0]).shape[0] == 0:
                print(f'Empty track: {idx}, {EXTRACTION_TRACK_NAMES[0]}')
                count_not_fixable += 1
                continue
                
            if vf.get_track_samples(EXTRACTION_TRACK_NAMES[1], 1/EXTRACTION_TRACK_SRATES[1]).shape[0] == 0:
                print(f'Empty track: {idx}, {EXTRACTION_TRACK_NAMES[1]}')
                count_not_fixable += 1
                continue
                
            if vf.get_track_samples(EXTRACTION_TRACK_NAMES[2], 1/EXTRACTION_TRACK_SRATES[2]).shape[0] == 0:
                print(f'Empty track: {idx}, {EXTRACTION_TRACK_NAMES[2]}')
                count_not_fixable += 1
                continue

            # if vf.get_track_samples(EXTRACTION_TRACK_NAMES[3], 1/EXTRACTION_TRACK_SRATES[3]).shape[0] == 0:
            #     print(f'Empty track: {idx}, {EXTRACTION_TRACK_NAMES[3]}')
            #     count_not_fixable += 1
            #     continue

        vf.to_vital(mini_path)
        count_minified += 1
    else:
        count_present += 1

print()
print(f'Count of cases of interest:           {cases_of_interest_idx.shape[0]}')
print(f'Count of vital files minified:        {count_minified}')
print(f'Count of vital files already present: {count_present}')
print(f'Count of vital files missing tracks:  {count_missing_tracks}')
print(f'Count of vital files not fixable:     {count_not_fixable}')
VitalFile('./vitaldb_cache/vital_all/0001.vital', '['EVENT', 'SNUADC/ART', 'SNUADC/ECG_II', 'BIS/EEG1_WAV']')

Count of cases of interest:           3110
Count of vital files minified:        0
Count of vital files already present: 3110
Count of vital files missing tracks:  0
Count of vital files not fixable:     0

Validate Mini Files¶

In [28]:
# Convert vital files to "mini" versions including only the subset of tracks defined in TRACK_NAMES above.
# Only perform conversion for the cases of interest.
# NOTE: If this cell is interrupted, it can be restarted and will continue where it left off.
count_missing_tracks = 0

# If true, perform fast validate that all mini files have 3 tracks.
FORCE_VALIDATE = False

if FORCE_VALIDATE:
    for idx in cases_of_interest_idx:
        mini_path = f'{VITAL_MINI}/{idx:04d}_mini.vital'

        if os.path.isfile(mini_path):
            vf = vitaldb.VitalFile(mini_path)

            if len(vf.get_track_names()) != 3:
                print(f'Missing track in vital file: {idx}, {set(TRACK_NAMES).difference(set(vf.get_track_names()))}')
                count_missing_tracks += 1

print()
print(f'Count of cases of interest:           {cases_of_interest_idx.shape[0]}')
print(f'Count of vital files missing tracks:  {count_missing_tracks}')
Count of cases of interest:           3110
Count of vital files missing tracks:  0

Filtering¶

Preprocessing characteristics are different for each of the three signal categories:

  • ABP: no preprocessing, use as-is
  • ECG: apply a 1-40Hz bandpass filter, then perform Z-score normalization
  • EEG: apply a 0.5-50Hz bandpass filter

apply_bandpass_filter() implements the bandpass filter using scipy.signal

apply_zscore_normalization() implements the Z-score normalization using numpy

In [29]:
from scipy.signal import butter, lfilter, spectrogram

# define two methods for data preprocessing

def apply_bandpass_filter(data, lowcut, highcut, fs, order=5):
    b, a = butter(order, [lowcut, highcut], fs=fs, btype='band')
    y = lfilter(b, a, np.nan_to_num(data))
    return y

def apply_zscore_normalization(signal):
    mean = np.nanmean(signal)
    std = np.nanstd(signal)
    return (signal - mean) / std
In [30]:
# Filtering Demonstration

# temp experimental, code to be incorporated into overall preloader process
# for now it's just dumping example plots of the before/after filtered signal data
caseidx = 1
file_path = f"{VITAL_MINI}/{caseidx:04d}_mini.vital"
vf = vitaldb.VitalFile(file_path, TRACK_NAMES)

originalAbp = None
filteredAbp = None
originalEcg = None
filteredEcg = None
originalEeg = None
filteredEeg = None

ABP_TRACK_NAME = "SNUADC/ART"
ECG_TRACK_NAME = "SNUADC/ECG_II"
EEG_TRACK_NAME = "BIS/EEG1_WAV"

for i, (track_name, rate) in enumerate(zip(TRACK_NAMES, TRACK_SRATES)):
    # Get samples for this track
    track_samples = vf.get_track_samples(track_name, 1/rate)
    #track_samples, _ = vf.get_samples(track_name, 1/rate)
    print(f"Track {track_name} @ {rate}Hz shape {len(track_samples)}")

    if track_name == ABP_TRACK_NAME:
        # ABP waveforms are used without further pre-processing
        originalAbp = track_samples
        filteredAbp = track_samples
    elif track_name == ECG_TRACK_NAME:
        originalEcg = track_samples
        # ECG waveforms are band-pass filtered between 1 and 40 Hz, and Z-score normalized
        # first apply bandpass filter
        filteredEcg = apply_bandpass_filter(track_samples, 1, 40, rate)
        # then do z-score normalization
        filteredEcg = apply_zscore_normalization(filteredEcg)
    elif track_name == EEG_TRACK_NAME:
        # EEG waveforms are band-pass filtered between 0.5 and 50 Hz
        originalEeg = track_samples
        filteredEeg = apply_bandpass_filter(track_samples, 0.5, 50, rate, 2)

def plotSignal(data, title):
    plt.figure(figsize=(20, 5))
    plt.plot(data)
    plt.title(title)
    plt.show()

plotSignal(originalAbp, "Original ABP")
plotSignal(originalAbp, "Unfiltered ABP")
plotSignal(originalEcg, "Original ECG")
plotSignal(filteredEcg, "Filtered ECG")
plotSignal(originalEeg, "Original EEG")
plotSignal(filteredEeg, "Filtered EEG")
Track SNUADC/ART @ 500Hz shape 5770575
Track SNUADC/ECG_II @ 500Hz shape 5770575
Track BIS/EEG1_WAV @ 128Hz shape 1477268
In [31]:
# Preprocess data tracks
ABP_TRACK_NAME = "SNUADC/ART"
ECG_TRACK_NAME = "SNUADC/ECG_II"
EEG_TRACK_NAME = "BIS/EEG1_WAV"
EVENT_TRACK_NAME = "EVENT"
MINI_FILE_FOLDER = VITAL_MINI
CACHE_FILE_FOLDER = VITAL_PREPROCESS_SCRATCH

if RESET_CACHE:
    TRACK_CACHE = None
    SEGMENT_CACHE = None

if TRACK_CACHE is None:
    TRACK_CACHE = {}
    SEGMENT_CACHE = {}

def get_track_data(case, print_when_file_loaded = False):
    parsedFile = None
    abp = None
    eeg = None
    ecg = None
    events = None

    for i, (track_name, rate) in enumerate(zip(EXTRACTION_TRACK_NAMES, EXTRACTION_TRACK_SRATES)):
        # use integer case id and track name, delimited by pipe, as cache key
        cache_label = f"{case}|{track_name}"
        if cache_label not in TRACK_CACHE:
            if parsedFile is None:
                file_path = f"{MINI_FILE_FOLDER}/{case:04d}_mini.vital"
                if print_when_file_loaded:
                    print(f"[{datetime.now()}] Loading vital file {file_path}")
                parsedFile = vitaldb.VitalFile(file_path, EXTRACTION_TRACK_NAMES)
            dataset = np.array(parsedFile.get_track_samples(track_name, 1/rate))
            if track_name == ABP_TRACK_NAME:
                # no filtering for ABP
                abp = dataset
                abp = pd.DataFrame(abp).ffill(axis=0).bfill(axis=0)[0].values
                if USE_MEMORY_CACHING:
                    TRACK_CACHE[cache_label] = abp
            elif track_name == ECG_TRACK_NAME:
                ecg = dataset
                # apply ECG filtering: first bandpass then do z-score normalization
                ecg = pd.DataFrame(ecg).ffill(axis=0).bfill(axis=0)[0].values
                ecg = apply_bandpass_filter(ecg, 1, 40, rate, 2)
                ecg = apply_zscore_normalization(ecg)
                
                if USE_MEMORY_CACHING:
                    TRACK_CACHE[cache_label] = ecg
            elif track_name == EEG_TRACK_NAME:
                eeg = dataset
                eeg = pd.DataFrame(eeg).ffill(axis=0).bfill(axis=0)[0].values
                # apply EEG filtering: bandpass only
                eeg = apply_bandpass_filter(eeg, 0.5, 50, rate, 2)
                if USE_MEMORY_CACHING:
                    TRACK_CACHE[cache_label] = eeg
            elif track_name == EVENT_TRACK_NAME:
                events = dataset
                if USE_MEMORY_CACHING:
                    TRACK_CACHE[cache_label] = events
        else:
            # cache hit, pull from cache
            if track_name == ABP_TRACK_NAME:
                abp = TRACK_CACHE[cache_label]
            elif track_name == ECG_TRACK_NAME:
                ecg = TRACK_CACHE[cache_label]
            elif track_name == EEG_TRACK_NAME:
                eeg = TRACK_CACHE[cache_label]
            elif track_name == EVENT_TRACK_NAME:
                events = TRACK_CACHE[cache_label]

    return (abp, ecg, eeg, events)

# ABP waveforms are used without further pre-processing
# ECG waveforms are band-pass filtered between 1 and 40 Hz, and Z-score normalized
# EEG waveforms are band-pass filtered between 0.5 and 50 Hz
if PRELOADING_CASES:
    # determine disk cache file label
    maxlabel = "ALL"
    if MAX_CASES is not None:
        maxlabel = str(MAX_CASES)
    picklefile = f"{CACHE_FILE_FOLDER}/{PREDICTION_WINDOW}_minutes_MAX{maxlabel}.trackcache"

    for track in tqdm(cases_of_interest_idx):
        # getting track data will cause a cache-check and fill when missing
        # will also apply appropriate filtering per track
        get_track_data(track, False)
    
    print(f"Generated track cache, {len(TRACK_CACHE)} records generated")


def get_segment_data(file_path):
    abp = None
    eeg = None
    ecg = None

    if USE_MEMORY_CACHING:
        if file_path in SEGMENT_CACHE:
            (abp, ecg, eeg) = SEGMENT_CACHE[file_path]
            return (abp, ecg, eeg)

    try:
        with h5py.File(file_path, 'r') as f:
            abp = np.array(f['abp'])
            ecg = np.array(f['ecg'])
            eeg = np.array(f['eeg'])
        
        abp = np.array(abp)
        eeg = np.array(eeg)
        ecg = np.array(ecg)

        if len(abp) > 30000:
            abp = abp[:30000]
        elif len(ecg) < 30000:
            abp = np.resize(abp, (30000))

        if len(ecg) > 30000:
            ecg = ecg[:30000]
        elif len(ecg) < 30000:
            ecg = np.resize(ecg, (30000))

        if len(eeg) > 7680:
            eeg = eeg[:7680]
        elif len(eeg) < 7680:
            eeg = np.resize(eeg, (7680))

        if USE_MEMORY_CACHING:
            SEGMENT_CACHE[file_path] = (abp, ecg, eeg)
    except:
        abp = None
        ecg = None
        eeg = None

    return (abp, ecg, eeg)

The following method is adapted from the preprocessing block of reference [6] (https://github.com/vitaldb/examples/blob/master/hypotension_art.ipynb)

The approach first finds an interoperative hypotensive event in the ABP waveform. It then backtracks to earlier in the waveform to extract a 60 second segment representing the waveform feature to use as model input. The figure below shows an example of this approach and is reproduced from the VitalDB example notebook referenced above.

Feature segment extraction

In [32]:
def getSurgeryBoundariesInSeconds(event, debug=False):
    eventIndices = np.argwhere(event==event)
    # we are looking for the last index where the string contains 'start
    lastStart = 0
    firstFinish = len(event)-1
    
    # find last start
    for idx in eventIndices:
        if 'started' in event[idx[0]]:
            if debug:
                print(event[idx[0]])
                print(idx[0])
            lastStart = idx[0]
    
    # find first finish
    for idx in eventIndices:
        if 'finish' in event[idx[0]]:
            if debug:
                print(event[idx[0]])
                print(idx[0])

            firstFinish = idx[0]
            break
    
    if debug:
        print(f'lastStart, firstFinish: {lastStart}, {firstFinish}')
    return (lastStart, firstFinish)
In [33]:
def areCaseSegmentsCached(caseid):
    seg_folder = f"{VITAL_EXTRACTED_SEGMENTS}/{caseid:04d}"
    return os.path.exists(seg_folder) and len(os.listdir(seg_folder)) > 0
In [34]:
def isAbpSegmentValidNumpy(samples, debug=False):
    valid = True
    if np.isnan(samples).mean() > 0.1:
        valid = False
        if debug:
            print(f">10% NaN")
    elif (samples > 200).any():
        valid = False
        if debug:
            print(f"Presence of BP > 200")
    elif (samples < 30).any():
        valid = False
        if debug:
            print(f"Presence of BP < 30")
    elif np.max(samples) - np.min(samples) < 30:
        if debug:
            print(f"Max - Min test < 30")
        valid = False
    elif (np.abs(np.diff(samples)) > 30).any():  # abrupt change -> noise
        if debug:
            print(f"Abrupt change (noise)")
        valid = False
    
    return valid
In [35]:
def isAbpSegmentValid(vf, debug=False):
    ABP_ECG_SRATE_HZ = 500
    ABP_TRACK_NAME = "SNUADC/ART"

    samples = np.array(vf.get_track_samples(ABP_TRACK_NAME, 1/ABP_ECG_SRATE_HZ))
    return isAbpSegmentValidNumpy(samples, debug)
In [36]:
def saveCaseSegments(caseid, positiveSegments, negativeSegments, compresslevel=9, debug=False, forceWrite=False):
    if len(positiveSegments) == 0 and len(negativeSegments) == 0:
        # exit early if no events found
        print(f'{caseid}: exit early, no segments to save')
        return

    # event composition
    # predictiveSegmentStart in seconds, predictiveSegmentEnd in seconds, predWindow (0 for negative), abp, ecg, eeg)
    # 0start, 1end, 2predwindow, 3abp, 4ecg, 5eeg

    seg_folder = f"{VITAL_EXTRACTED_SEGMENTS}/{caseid:04d}"
    if not os.path.exists(seg_folder):
        # if directory needs to be created, then there are no cached segments
        os.mkdir(seg_folder)
    else:
        if not forceWrite:
            # exit early if folder already exists, case already produced
            return

    # prior to writing files out, clear existing files
    for filename in os.listdir(seg_folder):
        file_path = os.path.join(seg_folder, filename)
        if debug:
            print(f'deleting: {file_path}')
        try:
            if os.path.isfile(file_path):
                os.unlink(file_path)
        except Exception as e:
            print('Failed to delete %s. Reason: %s' % (file_path, e))
    
    count_pos_saved = 0
    for i in range(0, len(positiveSegments)):
        event = positiveSegments[i]
        startIndex = event[0]
        endIndex = event[1]
        predWindow = event[2]
        abp = event[3]
        #ecg = event[4]
        #eeg = event[5]

        seg_filename = f"{caseid:04d}_{startIndex}_{predWindow:02d}_True.h5"
        seg_fullpath = f"{seg_folder}/{seg_filename}"
        if isAbpSegmentValidNumpy(abp, debug):
            count_pos_saved += 1

            abp = abp.tolist()
            ecg = event[4].tolist()
            eeg = event[5].tolist()
        
            f = h5py.File(seg_fullpath, "w")
            f.create_dataset('abp', data=abp, compression="gzip", compression_opts=compresslevel)
            f.create_dataset('ecg', data=ecg, compression="gzip", compression_opts=compresslevel)
            f.create_dataset('eeg', data=eeg, compression="gzip", compression_opts=compresslevel)
            
            f.flush()
            f.close()
            f = None

            abp = None
            ecg = None
            eeg = None

            # f.create_dataset('label', data=[1], compression="gzip", compression_opts=compresslevel)
            # f.create_dataset('pred_window', data=[event[2]], compression="gzip", compression_opts=compresslevel)
            # f.create_dataset('caseid', data=[caseid], compression="gzip", compression_opts=compresslevel)
        elif debug:
            print(f"{caseid:04d} {predWindow:02d}min {startIndex} starttime = ignored, segment validity issues")

    count_neg_saved = 0
    for i in range(0, len(negativeSegments)):
        event = negativeSegments[i]
        startIndex = event[0]
        endIndex = event[1]
        predWindow = event[2]
        abp = event[3]
        #ecg = event[4]
        #eeg = event[5]

        seg_filename = f"{caseid:04d}_{startIndex}_0_False.h5"
        seg_fullpath = f"{seg_folder}/{seg_filename}"
        if isAbpSegmentValidNumpy(abp, debug):
            count_neg_saved += 1

            abp = abp.tolist()
            ecg = event[4].tolist()
            eeg = event[5].tolist()
            
            f = h5py.File(seg_fullpath, "w")
            f.create_dataset('abp', data=abp, compression="gzip", compression_opts=compresslevel)
            f.create_dataset('ecg', data=ecg, compression="gzip", compression_opts=compresslevel)
            f.create_dataset('eeg', data=eeg, compression="gzip", compression_opts=compresslevel)
            
            f.flush()
            f.close()
            f = None

            abp = None
            ecg = None
            eeg = None

            # f.create_dataset('label', data=[0], compression="gzip", compression_opts=compresslevel)
            # f.create_dataset('pred_window', data=[0], compression="gzip", compression_opts=compresslevel)
            # f.create_dataset('caseid', data=[caseid], compression="gzip", compression_opts=compresslevel)
        elif debug:
            print(f"{caseid:04d} CleanWindow {startIndex} starttime = ignored, segment validity issues")
            
    if count_neg_saved == 0 and count_pos_saved == 0:
        print(f'{caseid}: nothing saved, all segments filtered')
In [37]:
# Generate hypotensive events
# Hypotensive events are defined as a 1-minute interval with sustained ABP of less than 65 mmHg
# Note: Hypotensive events should be at least 20 minutes apart to minimize potential residual effects from previous events
# Generate hypotension non-events
# To sample non-events, 30-minute segments where the ABP was above 75 mmHG were selected, and then
# three one-minute samples of each waveform were obtained from the middle of the segment
# both occur in extract_segments
#VITAL_EXTRACTED_SEGMENTS
def extract_segments(cases_of_interest_idx, debug=False, checkCache=True, forceWrite=False, returnSegments=False):
    # Sampling rate for ABP and ECG, Hz. These rates should be the same. Default = 500
    ABP_ECG_SRATE_HZ = 500

    # Sampling rate for EEG. Default = 128
    EEG_SRATE_HZ = 128

    # Final dataset for training and testing the model.
    positiveSegmentsMap = {}
    negativeSegmentsMap = {}
    iohEventsMap = {}
    cleanEventsMap = {}

    # Process each case and extract segments. For each segment identify presence of an event in the label zone.
    count_cases = len(cases_of_interest_idx)

    #for case_count, caseid in tqdm(enumerate(cases_of_interest_idx), total=count_cases):
    for case_count, caseid in enumerate(cases_of_interest_idx):
        if debug:
            print(f'Loading case: {caseid:04d}, ({case_count + 1} of {count_cases})')

        if checkCache and areCaseSegmentsCached(caseid):
            if debug:
                print(f'Skipping case: {caseid:04d}, already cached')
            # skip records we've already cached
            continue

        # read the arterial waveform
        (abp, ecg, eeg, event) = get_track_data(caseid)
        if debug:
            print(f'Length of {TRACK_NAMES[0]}:       {abp.shape[0]}')
            print(f'Length of {TRACK_NAMES[1]}:    {ecg.shape[0]}')
            print(f'Length of {TRACK_NAMES[2]}:     {eeg.shape[0]}')

        (startInSeconds, endInSeconds) = getSurgeryBoundariesInSeconds(event)
        if debug:
            print(f"Event markers indicate that surgery begins at {startInSeconds}s and ends at {endInSeconds}s.")

        track_length_seconds = int(len(abp) / ABP_ECG_SRATE_HZ)
        if debug:
            print(f"Processing case {caseid} with length {track_length_seconds}s")

        
        # check if the ABP segment in the surgery window is valid
        if debug:
            isSurgerySegmentValid = isAbpSegmentValidNumpy(abp[startInSeconds:endInSeconds])
            print(f'{caseid}: surgery segment valid: {isSurgerySegmentValid}')
        
        iohEvents = []
        cleanEvents = []
        i = 0
        started = False
        eofReached = False
        trackStartIndex = None

        # set i pointer (which operates in seconds) to start marker for surgery
        i = startInSeconds

        # FIRST PASS
        # in the first forward pass, we are going to identify the start/end boundaries of all IOH events within the case
        while i < track_length_seconds - 60 and i < endInSeconds:
            segmentStart = None
            segmentEnd = None
            segFound = False

            # look forward one minute
            abpSeg = abp[i * ABP_ECG_SRATE_HZ:(i + 60) * ABP_ECG_SRATE_HZ]

            # roll forward until we hit a one minute window where mean ABP >= 65 so we know leads are connected and it's tracking
            if not started:
                if np.nanmean(abpSeg) >= 65:
                    started = True
                    trackStartIndex = i
            # if we're started and mean abp for the window is <65, we are starting a new IOH event
            elif np.nanmean(abpSeg) < 65:
                segmentStart = i
                # now seek forward to find end of event, perpetually checking the lats minute of the IOH event
                for j in range(i + 60, track_length_seconds):
                    # look backward one minute
                    abpSegForward = abp[(j - 60) * ABP_ECG_SRATE_HZ:j * ABP_ECG_SRATE_HZ]
                    if np.nanmean(abpSegForward) >= 65:
                        segmentEnd = j - 1
                        break
                if segmentEnd is None:
                    eofReached = True
                else:
                    # otherwise, end of the IOH segment has been reached, record it
                    iohEvents.append((segmentStart, segmentEnd))
                    segFound = True
                    
                    if debug:
                        t_abp = abp[segmentStart * ABP_ECG_SRATE_HZ:segmentEnd * ABP_ECG_SRATE_HZ]
                        isIohSegmentValid = isAbpSegmentValidNumpy(t_abp)
                        print(f'{caseid}: ioh segment valid: {isIohSegmentValid}, {segmentStart}, {segmentEnd}, {t_abp.shape}')

            i += 1
            if not started:
                continue
            elif eofReached:
                break
            elif segFound:
                i = segmentEnd + 1

        # SECOND PASS
        # in the second forward pass, we are going to identify the start/end boundaries of all non-overlapping 30 minute "clean" windows
        # reuse the 'start of signal' index from our first pass
        if trackStartIndex is None:
            trackStartIndex = startInSeconds
        i = trackStartIndex
        eofReached = False

        while i < track_length_seconds - 1800 and i < endInSeconds:
            segmentStart = None
            segmentEnd = None
            segFound = False

            startIndex = i
            endIndex = i + 1800

            # check to see if this 30 minute window overlaps any IOH events, if so ffwd to end of latest overlapping IOH
            overlapFound = False
            latestEnd = None
            for event in iohEvents:
                # case 1: starts during an event
                if startIndex >= event[0] and startIndex < event[1]:
                    latestEnd = event[1]
                    overlapFound = True
                # case 2: ends during an event
                elif endIndex >= event[0] and endIndex < event[1]:
                    latestEnd = event[1]
                    overlapFound = True
                # case 3: event occurs entirely inside of the window
                elif startIndex < event[0] and endIndex > event[1]:
                    latestEnd = event[1]
                    overlapFound = True

            # FFWD if we found an overlap
            if overlapFound:
                i = latestEnd + 1
                continue

            # look forward 30 minutes
            abpSeg = abp[startIndex * ABP_ECG_SRATE_HZ:endIndex * ABP_ECG_SRATE_HZ]

            # if we're started and mean abp for the window is >= 75, we are starting a new clean event
            if np.nanmean(abpSeg) >= 75:
                overlapFound = False
                latestEnd = None
                for event in iohEvents:
                    # case 1: starts during an event
                    if startIndex >= event[0] and startIndex < event[1]:
                        latestEnd = event[1]
                        overlapFound = True
                    # case 2: ends during an event
                    elif endIndex >= event[0] and endIndex < event[1]:
                        latestEnd = event[1]
                        overlapFound = True
                    # case 3: event occurs entirely inside of the window
                    elif startIndex < event[0] and endIndex > event[1]:
                        latestEnd = event[1]
                        overlapFound = True

                if not overlapFound:
                    segFound = True
                    segmentEnd = endIndex
                    cleanEvents.append((startIndex, endIndex))
                    
                    if debug:
                        t_abp = abp[startIndex * ABP_ECG_SRATE_HZ:endIndex * ABP_ECG_SRATE_HZ]
                        isCleanSegmentValid = isAbpSegmentValidNumpy(t_abp)
                        print(f'{caseid}: clean segment valid: {isCleanSegmentValid}, {startIndex}, {endIndex}, {t_abp.shape}')

            i += 10
            if segFound:
                i = segmentEnd + 1

        if debug:
            print(f"IOH Events for case {caseid}: {iohEvents}")
            print(f"Clean Events for case {caseid}: {cleanEvents}")

        positiveSegments = []
        negativeSegments = []

        # THIRD PASS
        # in the third pass, we will use the collections of ioh event windows to generate our actual extracted segments based on our prediction window (positive labels)
        for i in range(0, len(iohEvents)):
            if debug:
                print(f"Checking event {iohEvents[i]}")
            # we want to review current event boundaries, as well as previous event boundaries if available
            event = iohEvents[i]
            previousEvent = None
            if i > 0:
                previousEvent = iohEvents[i - 1]

            for predWindow in ALL_PREDICTION_WINDOWS:
                if debug:
                    print(f"Checking event {iohEvents[i]} for pred {predWindow}")
                iohEventStart = event[0]
                predictiveSegmentEnd = event[0] - (predWindow*60)
                predictiveSegmentStart = predictiveSegmentEnd - 60

                if (predictiveSegmentStart < 0):
                    # don't rewind before the beginning of the track
                    if debug:
                        print(f"Checking event {iohEvents[i]} for pred {predWindow} - exit, before beginning")
                    continue
                elif (predictiveSegmentStart < trackStartIndex):
                    # don't rewind before the beginning of signal in track
                    if debug:
                        print(f"Checking event {iohEvents[i]} for pred {predWindow} - exit, before track start")
                    continue
                elif previousEvent is not None:
                    # does this event window come before or during the previous event?
                    overlapFound = False
                    # case 1: starts during an event
                    if predictiveSegmentStart >= previousEvent[0] and predictiveSegmentStart < previousEvent[1]:
                        overlapFound = True
                    # case 2: ends during an event
                    elif iohEventStart >= previousEvent[0] and iohEventStart < previousEvent[1]:
                        overlapFound = True
                    # case 3: event occurs entirely inside of the window
                    elif predictiveSegmentStart < previousEvent[0] and iohEventStart > previousEvent[1]:
                        overlapFound = True
                    # do not extract a case if we overlap witha nother IOH
                    if overlapFound:
                        if debug:
                            print(f"Checking event {iohEvents[i]} for pred {predWindow} - exit, overlap with earlier segment")
                        continue

                # track the positive segment
                positiveSegments.append((predictiveSegmentStart, predictiveSegmentEnd, predWindow,
                    abp[predictiveSegmentStart*ABP_ECG_SRATE_HZ:predictiveSegmentEnd*ABP_ECG_SRATE_HZ],
                    ecg[predictiveSegmentStart*ABP_ECG_SRATE_HZ:predictiveSegmentEnd*ABP_ECG_SRATE_HZ],
                    eeg[predictiveSegmentStart*EEG_SRATE_HZ:predictiveSegmentEnd*EEG_SRATE_HZ]))

        # FOURTH PASS
        # in the fourth and final pass, we will use the collections of clean event windows to generate our actual extracted segments based (negative labels)
        for i in range(0, len(cleanEvents)):
            # everything will be 30 minutes long at least
            event = cleanEvents[i]
            # choose sample 1 @ 10 minutes
            # choose sample 2 @ 15 minutes
            # choose sample 3 @ 20 minutes
            timeAtTen = event[0] + 600
            timeAtFifteen = event[0] + 900
            timeAtTwenty = event[0] + 1200

            negativeSegments.append((timeAtTen, timeAtTen + 60, 0,
                                   abp[timeAtTen*ABP_ECG_SRATE_HZ:(timeAtTen + 60)*ABP_ECG_SRATE_HZ],
                                   ecg[timeAtTen*ABP_ECG_SRATE_HZ:(timeAtTen + 60)*ABP_ECG_SRATE_HZ],
                                   eeg[timeAtTen*EEG_SRATE_HZ:(timeAtTen + 60)*EEG_SRATE_HZ]))
            negativeSegments.append((timeAtFifteen, timeAtFifteen + 60, 0,
                                   abp[timeAtFifteen*ABP_ECG_SRATE_HZ:(timeAtFifteen + 60)*ABP_ECG_SRATE_HZ],
                                   ecg[timeAtFifteen*ABP_ECG_SRATE_HZ:(timeAtFifteen + 60)*ABP_ECG_SRATE_HZ],
                                   eeg[timeAtFifteen*EEG_SRATE_HZ:(timeAtFifteen + 60)*EEG_SRATE_HZ]))
            negativeSegments.append((timeAtTwenty, timeAtTwenty + 60, 0,
                                   abp[timeAtTwenty*ABP_ECG_SRATE_HZ:(timeAtTwenty + 60)*ABP_ECG_SRATE_HZ],
                                   ecg[timeAtTwenty*ABP_ECG_SRATE_HZ:(timeAtTwenty + 60)*ABP_ECG_SRATE_HZ],
                                   eeg[timeAtTwenty*EEG_SRATE_HZ:(timeAtTwenty + 60)*EEG_SRATE_HZ]))

        if returnSegments:
            positiveSegmentsMap[caseid] = positiveSegments
            negativeSegmentsMap[caseid] = negativeSegments
            iohEventsMap[caseid] = iohEvents
            cleanEventsMap[caseid] = cleanEvents
        
        saveCaseSegments(caseid, positiveSegments, negativeSegments, 9, debug=debug, forceWrite=forceWrite)

        #if debug:
        print(f'{caseid}: positiveSegments: {len(positiveSegments)}, negativeSegments: {len(negativeSegments)}')

    return positiveSegmentsMap, negativeSegmentsMap, iohEventsMap, cleanEventsMap

Case Extraction - Generage Segments Needed For Training¶

Ensure that all needed segments are in place for the cases that are being used. If data is already stored on disk this method returns immediately.

In [38]:
print('here')
here
In [39]:
MANUAL_EXTRACT=True

if MANUAL_EXTRACT:
    mycoi = cases_of_interest_idx
    #mycoi = cases_of_interest_idx[:2800]
    #mycoi = [1]

    cnt = 0
    mod = 0
    for ci in mycoi:
        cnt += 1
        if mod % 100 == 0:
            print(f'count processed: {mod}, current case index: {ci}')
        try:
            p, n, i, c = extract_segments([ci], debug=False, checkCache=True, forceWrite=True, returnSegments=False)
            p = None
            n = None
            i = None
            c = None
        except:
            print(f'error on extract segment: {ci}')
        mod += 1
    print(f'extracted: {cnt}')
count processed: 0, current case index: 1
count processed: 100, current case index: 198
count processed: 200, current case index: 431
count processed: 300, current case index: 665
724: exit early, no segments to save
724: positiveSegments: 0, negativeSegments: 0
818: exit early, no segments to save
818: positiveSegments: 0, negativeSegments: 0
count processed: 400, current case index: 853
count processed: 500, current case index: 1046
count processed: 600, current case index: 1236
1271: exit early, no segments to save
1271: positiveSegments: 0, negativeSegments: 0
count processed: 700, current case index: 1440
1505: exit early, no segments to save
1505: positiveSegments: 0, negativeSegments: 0
count processed: 800, current case index: 1639
count processed: 900, current case index: 1843
count processed: 1000, current case index: 2049
2218: exit early, no segments to save
2218: positiveSegments: 0, negativeSegments: 0
count processed: 1100, current case index: 2281
count processed: 1200, current case index: 2469
count processed: 1300, current case index: 2665
count processed: 1400, current case index: 2888
count processed: 1500, current case index: 3092
count processed: 1600, current case index: 3279
3413: exit early, no segments to save
3413: positiveSegments: 0, negativeSegments: 0
count processed: 1700, current case index: 3475
3476: exit early, no segments to save
3476: positiveSegments: 0, negativeSegments: 0
3533: exit early, no segments to save
3533: positiveSegments: 0, negativeSegments: 0
count processed: 1800, current case index: 3694
count processed: 1900, current case index: 3887
3992: exit early, no segments to save
3992: positiveSegments: 0, negativeSegments: 0
count processed: 2000, current case index: 4091
4187: nothing saved, all segments filtered
4187: positiveSegments: 0, negativeSegments: 18
count processed: 2100, current case index: 4296
4328: exit early, no segments to save
4328: positiveSegments: 0, negativeSegments: 0
count processed: 2200, current case index: 4509
4648: exit early, no segments to save
4648: positiveSegments: 0, negativeSegments: 0
4703: exit early, no segments to save
4703: positiveSegments: 0, negativeSegments: 0
count processed: 2300, current case index: 4732
4733: exit early, no segments to save
4733: positiveSegments: 0, negativeSegments: 0
4834: nothing saved, all segments filtered
4834: positiveSegments: 3, negativeSegments: 0
4836: nothing saved, all segments filtered
4836: positiveSegments: 11, negativeSegments: 6
count processed: 2400, current case index: 4929
4985: nothing saved, all segments filtered
4985: positiveSegments: 1, negativeSegments: 0
5130: exit early, no segments to save
5130: positiveSegments: 0, negativeSegments: 0
count processed: 2500, current case index: 5142
5175: nothing saved, all segments filtered
5175: positiveSegments: 2, negativeSegments: 0
5327: nothing saved, all segments filtered
5327: positiveSegments: 4, negativeSegments: 12
count processed: 2600, current case index: 5346
5501: exit early, no segments to save
5501: positiveSegments: 0, negativeSegments: 0
count processed: 2700, current case index: 5564
5587: nothing saved, all segments filtered
5587: positiveSegments: 2, negativeSegments: 0
5693: exit early, no segments to save
5693: positiveSegments: 0, negativeSegments: 0
count processed: 2800, current case index: 5771
5908: exit early, no segments to save
5908: positiveSegments: 0, negativeSegments: 0
count processed: 2900, current case index: 5974
6131: nothing saved, all segments filtered
6131: positiveSegments: 2, negativeSegments: 0
count processed: 3000, current case index: 6174
count processed: 3100, current case index: 6372
extracted: 3110

Track and Segment Validity Checks¶

In [40]:
def printAbp(case_id_to_check, plot_invalid_only=False):
        vf_path = f'{VITAL_MINI}/{case_id_to_check:04d}_mini.vital'
        vf = vitaldb.VitalFile(vf_path)
        abp = vf.to_numpy(TRACK_NAMES[0], 1/500)
        
        print(f'Case {case_id_to_check}')
        print(f'ABP Shape: {abp.shape}')

        print(f'nanmin: {np.nanmin(abp)}')
        print(f'nanmean: {np.nanmean(abp)}')
        print(f'nanmax: {np.nanmax(abp)}')
        
        is_valid = isAbpSegmentValidNumpy(abp, debug=True)
        print(f'valid: {is_valid}')

        if plot_invalid_only and is_valid:
            return
            
        plt.figure(figsize=(20, 5))
        plt_color = 'C0' if is_valid else 'red'
        plt.plot(abp, plt_color)
        plt.title(f'ABP - Entire Track - Case {case_id_to_check} - {abp.shape[0] / 500} seconds')
        plt.axhline(y = 65, color = 'maroon', linestyle = '--')
        plt.show()
In [41]:
def printSegments(segmentsMap, case_id_to_check, print_label, normalize=False):
    for (x1, x2, r, abp, ecg, eeg) in segmentsMap[case_id_to_check]:
        print(f'{print_label}: Case {case_id_to_check}')
        print(f'lookback window: {r} min')
        print(f'start time: {x1}')
        print(f'end time: {x2}')
        print(f'length: {x2 - x1} sec')
        
        print(f'ABP Shape: {abp.shape}')
        print(f'ECG Shape: {ecg.shape}')
        print(f'EEG Shape: {eeg.shape}')

        print(f'nanmin: {np.nanmin(abp)}')
        print(f'nanmean: {np.nanmean(abp)}')
        print(f'nanmax: {np.nanmax(abp)}')
        
        is_valid = isAbpSegmentValidNumpy(abp, debug=True)
        print(f'valid: {is_valid}')

        # ABP normalization
        x_abp = np.copy(abp)
        if normalize:
            x_abp -= 65
            x_abp /= 65

        plt.figure(figsize=(20, 5))
        plt_color = 'C0' if is_valid else 'red'
        plt.plot(x_abp, plt_color)
        plt.title('ABP')
        plt.axhline(y = 65, color = 'maroon', linestyle = '--')
        plt.show()

        plt.figure(figsize=(20, 5))
        plt.plot(ecg, 'teal')
        plt.title('ECG')
        plt.show()

        plt.figure(figsize=(20, 5))
        plt.plot(eeg, 'indigo')
        plt.title('EEG')
        plt.show()

        print()
In [42]:
def printEvents(abp_raw, eventsMap, case_id_to_check, print_label, normalize=False):
    for (x1, x2) in eventsMap[case_id_to_check]:
        print(f'{print_label}: Case {case_id_to_check}')
        print(f'start time: {x1}')
        print(f'end time: {x2}')
        print(f'length: {x2 - x1} sec')

        abp = abp_raw[x1*500:x2*500]
        print(f'ABP Shape: {abp.shape}')

        print(f'nanmin: {np.nanmin(abp)}')
        print(f'nanmean: {np.nanmean(abp)}')
        print(f'nanmax: {np.nanmax(abp)}')
        
        is_valid = isAbpSegmentValidNumpy(abp, debug=True)
        print(f'valid: {is_valid}')

        # ABP normalization
        x_abp = np.copy(abp)
        if normalize:
            x_abp -= 65
            x_abp /= 65

        plt.figure(figsize=(20, 5))
        plt_color = 'C0' if is_valid else 'red'
        plt.plot(x_abp, plt_color)
        plt.title('ABP')
        plt.axhline(y = 65, color = 'maroon', linestyle = '--')
        plt.show()

        print()

Reality Check All Cases¶

In [43]:
# Check if all ABPs are well formed.
DISPLAY_REALITY_CHECK_ABP=True
DISPLAY_REALITY_CHECK_ABP_FIRST_ONLY=True

if DISPLAY_REALITY_CHECK_ABP:
    for case_id_to_check in cases_of_interest_idx:
        printAbp(case_id_to_check, plot_invalid_only=False)
        
        if DISPLAY_REALITY_CHECK_ABP_FIRST_ONLY:
            break
Case 1
ABP Shape: (5770575, 1)
nanmin: -495.6260070800781
nanmean: 78.15251159667969
nanmax: 374.3236389160156
Presence of BP > 200
valid: False

Validate Malformed Vital Files - Missing One Or More Tracks¶

In [44]:
# These are Vital Files removed because of malformed ABP waveforms.
DISPLAY_MALFORMED_ABP=True
DISPLAY_MALFORMED_ABP_FIRST_ONLY=True

if DISPLAY_MALFORMED_ABP:
    malformed_case_ids = pd.read_csv('malformed_tracks_filter.csv', header=None, names=['caseid']).set_index('caseid').index

    for case_id_to_check in malformed_case_ids:
        printAbp(case_id_to_check)
        
        if DISPLAY_MALFORMED_ABP_FIRST_ONLY:
            break
Case 3
ABP Shape: (2196524, 1)
nanmin: -117.43000030517578
nanmean: 0.6060270667076111
nanmax: 85.98619842529297
Presence of BP < 30
valid: False

Validate Cases With No Segments Saved¶

In [45]:
DISPLAY_NO_SEGMENTS_CASES=True
DISPLAY_NO_SEGMENTS_CASES_FIRST_ONLY=True

if DISPLAY_NO_SEGMENTS_CASES:
    no_segments_case_ids = [3413, 3476, 3533, 3992, 4328, 4648, 4703, 4733, 5130, 5501, 5693, 5908]

    for case_id_to_check in no_segments_case_ids:
        printAbp(case_id_to_check)
        
        if DISPLAY_NO_SEGMENTS_CASES_FIRST_ONLY:
            break
Case 3413
ABP Shape: (3429927, 1)
nanmin: -228.025146484375
nanmean: 48.4425163269043
nanmax: 293.3521423339844
>10% NaN
valid: False

Select Case For Segment Extraction Validation¶

Generate segment data for one or more cases.

In [46]:
#mycoi = cases_of_interest_idx
mycoi = cases_of_interest_idx[:1]
#mycoi = [1]

positiveSegmentsMap, negativeSegmentsMap, iohEventsMap, cleanEventsMap = \
    extract_segments(mycoi, debug=False, checkCache=False, forceWrite=False, returnSegments=True)
1: positiveSegments: 12, negativeSegments: 9

Select a specific case to check.

In [47]:
case_id_to_check = cases_of_interest_idx[0]
#case_id_to_check = 1

print(case_id_to_check)
1
In [48]:
print((
    len(positiveSegmentsMap[case_id_to_check]),
    len(negativeSegmentsMap[case_id_to_check]),
    len(iohEventsMap[case_id_to_check]),
    len(cleanEventsMap[case_id_to_check])
))
(12, 9, 7, 3)
In [49]:
printAbp(case_id_to_check)
Case 1
ABP Shape: (5770575, 1)
nanmin: -495.6260070800781
nanmean: 78.15251159667969
nanmax: 374.3236389160156
Presence of BP > 200
valid: False

Positive Segments for Case - IOH Events¶

In [50]:
printSegments(positiveSegmentsMap, case_id_to_check, 'Positive Segment - IOH Event', normalize=False)
Positive Segment - IOH Event: Case 1
lookback window: 3 min
start time: 1548
end time: 1608
length: 60 sec
ABP Shape: (30000,)
ECG Shape: (30000,)
EEG Shape: (7680,)
nanmin: 46.487884521484375
nanmean: 73.00869750976562
nanmax: 113.63497924804688
valid: True
Positive Segment - IOH Event: Case 1
lookback window: 5 min
start time: 1428
end time: 1488
length: 60 sec
ABP Shape: (30000,)
ECG Shape: (30000,)
EEG Shape: (7680,)
nanmin: 41.550628662109375
nanmean: 74.47395324707031
nanmax: 128.44686889648438
valid: True
Positive Segment - IOH Event: Case 1
lookback window: 10 min
start time: 1128
end time: 1188
length: 60 sec
ABP Shape: (30000,)
ECG Shape: (30000,)
EEG Shape: (7680,)
nanmin: 53.400115966796875
nanmean: 88.63211059570312
nanmax: 135.35903930664062
valid: True
Positive Segment - IOH Event: Case 1
lookback window: 15 min
start time: 828
end time: 888
length: 60 sec
ABP Shape: (30000,)
ECG Shape: (30000,)
EEG Shape: (7680,)
nanmin: 23.776397705078125
nanmean: 108.88127136230469
nanmax: 182.75698852539062
Presence of BP < 30
valid: False
Positive Segment - IOH Event: Case 1
lookback window: 3 min
start time: 3873
end time: 3933
length: 60 sec
ABP Shape: (30000,)
ECG Shape: (30000,)
EEG Shape: (7680,)
nanmin: 46.487884521484375
nanmean: 75.3544692993164
nanmax: 124.49703979492188
valid: True
Positive Segment - IOH Event: Case 1
lookback window: 5 min
start time: 3753
end time: 3813
length: 60 sec
ABP Shape: (30000,)
ECG Shape: (30000,)
EEG Shape: (7680,)
nanmin: 45.500457763671875
nanmean: 73.97709655761719
nanmax: 122.52212524414062
valid: True
Positive Segment - IOH Event: Case 1
lookback window: 10 min
start time: 3453
end time: 3513
length: 60 sec
ABP Shape: (30000,)
ECG Shape: (30000,)
EEG Shape: (7680,)
nanmin: 52.412628173828125
nanmean: 86.52787780761719
nanmax: 148.19595336914062
valid: True
Positive Segment - IOH Event: Case 1
lookback window: 15 min
start time: 3153
end time: 3213
length: 60 sec
ABP Shape: (30000,)
ECG Shape: (30000,)
EEG Shape: (7680,)
nanmin: 58.337371826171875
nanmean: 100.94121551513672
nanmax: 165.97018432617188
valid: True
Positive Segment - IOH Event: Case 1
lookback window: 3 min
start time: 8856
end time: 8916
length: 60 sec
ABP Shape: (30000,)
ECG Shape: (30000,)
EEG Shape: (7680,)
nanmin: 64.26211547851562
nanmean: 97.06536102294922
nanmax: 157.08309936523438
valid: True
Positive Segment - IOH Event: Case 1
lookback window: 5 min
start time: 8736
end time: 8796
length: 60 sec
ABP Shape: (30000,)
ECG Shape: (30000,)
EEG Shape: (7680,)
nanmin: 69.19943237304688
nanmean: 105.55238342285156
nanmax: 163.00784301757812
valid: True
Positive Segment - IOH Event: Case 1
lookback window: 10 min
start time: 8436
end time: 8496
length: 60 sec
ABP Shape: (30000,)
ECG Shape: (30000,)
EEG Shape: (7680,)
nanmin: -88.793701171875
nanmean: 130.62982177734375
nanmax: 305.2016296386719
Presence of BP > 200
valid: False
Positive Segment - IOH Event: Case 1
lookback window: 15 min
start time: 8136
end time: 8196
length: 60 sec
ABP Shape: (30000,)
ECG Shape: (30000,)
EEG Shape: (7680,)
nanmin: 62.287200927734375
nanmean: 92.04357147216797
nanmax: 138.32138061523438
valid: True

Negative Segments for Case - Non Events¶

In [51]:
printSegments(negativeSegmentsMap, case_id_to_check, 'Negative Segment - Non-Event', normalize=False)
Negative Segment - Non-Event: Case 1
lookback window: 0 min
start time: 5951
end time: 6011
length: 60 sec
ABP Shape: (30000,)
ECG Shape: (30000,)
EEG Shape: (7680,)
nanmin: 52.412628173828125
nanmean: 76.35643005371094
nanmax: 120.54721069335938
valid: True
Negative Segment - Non-Event: Case 1
lookback window: 0 min
start time: 6251
end time: 6311
length: 60 sec
ABP Shape: (30000,)
ECG Shape: (30000,)
EEG Shape: (7680,)
nanmin: 54.387542724609375
nanmean: 77.73150634765625
nanmax: 120.54721069335938
valid: True
Negative Segment - Non-Event: Case 1
lookback window: 0 min
start time: 6551
end time: 6611
length: 60 sec
ABP Shape: (30000,)
ECG Shape: (30000,)
EEG Shape: (7680,)
nanmin: 58.337371826171875
nanmean: 85.06976318359375
nanmax: 133.38412475585938
valid: True
Negative Segment - Non-Event: Case 1
lookback window: 0 min
start time: 7752
end time: 7812
length: 60 sec
ABP Shape: (30000,)
ECG Shape: (30000,)
EEG Shape: (7680,)
nanmin: 55.375030517578125
nanmean: 80.11844635009766
nanmax: 130.42178344726562
valid: True
Negative Segment - Non-Event: Case 1
lookback window: 0 min
start time: 8052
end time: 8112
length: 60 sec
ABP Shape: (30000,)
ECG Shape: (30000,)
EEG Shape: (7680,)
nanmin: 60.312286376953125
nanmean: 88.32589721679688
nanmax: 134.37161254882812
valid: True
Negative Segment - Non-Event: Case 1
lookback window: 0 min
start time: 8352
end time: 8412
length: 60 sec
ABP Shape: (30000,)
ECG Shape: (30000,)
EEG Shape: (7680,)
nanmin: 68.21194458007812
nanmean: 182.59963989257812
nanmax: 368.3988952636719
Presence of BP > 200
valid: False
Negative Segment - Non-Event: Case 1
lookback window: 0 min
start time: 10104
end time: 10164
length: 60 sec
ABP Shape: (30000,)
ECG Shape: (30000,)
EEG Shape: (7680,)
nanmin: 48.462799072265625
nanmean: 72.81173706054688
nanmax: 115.60989379882812
valid: True
Negative Segment - Non-Event: Case 1
lookback window: 0 min
start time: 10404
end time: 10464
length: 60 sec
ABP Shape: (30000,)
ECG Shape: (30000,)
EEG Shape: (7680,)
nanmin: -7.822235107421875
nanmean: 106.73753356933594
nanmax: 236.07968139648438
Presence of BP > 200
valid: False
Negative Segment - Non-Event: Case 1
lookback window: 0 min
start time: 10704
end time: 10764
length: 60 sec
ABP Shape: (30000,)
ECG Shape: (30000,)
EEG Shape: (7680,)
nanmin: 110.67263793945312
nanmean: 172.22396850585938
nanmax: 239.04202270507812
Presence of BP > 200
valid: False

IOH Event Segments for Case - Positive Segments Identified From These¶

In [52]:
tmp_vf_path = f'{VITAL_MINI}/{case_id_to_check:04d}_mini.vital'
tmp_vf = vitaldb.VitalFile(tmp_vf_path)
tmp_abp = tmp_vf.to_numpy(TRACK_NAMES[0], 1/500)
In [53]:
printEvents(tmp_abp, iohEventsMap, case_id_to_check, 'IOH Event Segment', normalize=False)
IOH Event Segment: Case 1
start time: 1788
end time: 1849
length: 61 sec
ABP Shape: (30500, 1)
nanmin: 32.663482666015625
nanmean: 64.93988037109375
nanmax: 123.50955200195312
valid: True
IOH Event Segment: Case 1
start time: 1850
end time: 2113
length: 263 sec
ABP Shape: (131500, 1)
nanmin: 37.600799560546875
nanmean: 63.139060974121094
nanmax: 101.78549194335938
valid: True
IOH Event Segment: Case 1
start time: 2314
end time: 2375
length: 61 sec
ABP Shape: (30500, 1)
nanmin: -262.5861511230469
nanmean: 65.14369201660156
nanmax: 343.7124938964844
Presence of BP > 200
valid: False
IOH Event Segment: Case 1
start time: 4113
end time: 4199
length: 86 sec
ABP Shape: (43000, 1)
nanmin: 22.788909912109375
nanmean: 65.0725326538086
nanmax: 153.13327026367188
Presence of BP < 30
valid: False
IOH Event Segment: Case 1
start time: 4261
end time: 5350
length: 1089 sec
ABP Shape: (544500, 1)
nanmin: 36.613311767578125
nanmean: 60.451026916503906
nanmax: 110.67263793945312
valid: True
IOH Event Segment: Case 1
start time: 9096
end time: 9156
length: 60 sec
ABP Shape: (30000, 1)
nanmin: 40.563140869140625
nanmean: 64.9837646484375
nanmax: 108.69772338867188
valid: True
IOH Event Segment: Case 1
start time: 9157
end time: 9503
length: 346 sec
ABP Shape: (173000, 1)
nanmin: 39.575714111328125
nanmean: 62.33021545410156
nanmax: 104.74789428710938
valid: True

Clean Event Segments for Case - Negative Segments Identified From These¶

In [54]:
printEvents(tmp_abp, cleanEventsMap, case_id_to_check, 'Clean Event Segment', normalize=False)
Clean Event Segment: Case 1
start time: 5351
end time: 7151
length: 1800 sec
ABP Shape: (900000, 1)
nanmin: 40.563140869140625
nanmean: 84.04818725585938
nanmax: 151.15835571289062
valid: True
Clean Event Segment: Case 1
start time: 7152
end time: 8952
length: 1800 sec
ABP Shape: (900000, 1)
nanmin: -495.6260070800781
nanmean: 99.71124267578125
nanmax: 368.3988952636719
Presence of BP > 200
valid: False
Clean Event Segment: Case 1
start time: 9504
end time: 11304
length: 1800 sec
ABP Shape: (900000, 1)
nanmin: -49.295440673828125
nanmean: 83.3201675415039
nanmax: 346.6748352050781
Presence of BP > 200
valid: False

In [55]:
# free memory
tmp_abp = None

Generate Train/Val/Test Splits¶

In [56]:
def get_segment_attributes_from_filename(file_path):
    pieces = os.path.basename(file_path).split('_')
    case = int(pieces[0])
    startX = int(pieces[1])
    predWindow = int(pieces[2])
    label = pieces[3].replace('.h5', '')
    return (case, startX, predWindow, label)
In [57]:
count_negative_samples = 0
count_positive_samples = 0

samples = []

from glob import glob
seg_folder = f"{VITAL_EXTRACTED_SEGMENTS}"
filenames = [y for x in os.walk(seg_folder) for y in glob(os.path.join(x[0], '*.h5'))]

for filename in filenames:
    (case, start_x, pred_window, label) = get_segment_attributes_from_filename(filename)
    #print((case, start_x, pred_window, label))
    
    # only load cases for cases of interest; this folder could have segments for hundreds of cases
    if case not in cases_of_interest_idx:
        continue

    #PREDICTION_WINDOW = 3
    if pred_window == 0 or pred_window == PREDICTION_WINDOW or PREDICTION_WINDOW == 'ALL':
        #print((case, start_x, pred_window, label))
        if label == 'True':
            count_positive_samples += 1
        else:
            count_negative_samples += 1
        sample = (filename, label)
        samples.append(sample)

print()
print(f"samples loaded:         {len(samples):5} ")
print(f'count negative samples: {count_negative_samples:5}')
print(f'count positive samples: {count_positive_samples:5}')
samples loaded:         42869 
count negative samples: 37572
count positive samples:  5297
In [58]:
# Divide by cases
sample_cases = defaultdict(lambda: []) 

for fn, _ in samples:
    (case, start_x, pred_window, label) = get_segment_attributes_from_filename(fn)
    sample_cases[case].append((fn, label))

# understand any missing cases of interest
sample_cases_idx = pd.Index(sample_cases.keys())
missing_case_ids = cases_of_interest_idx.difference(sample_cases_idx)
print(f'cases with no samples: {missing_case_ids.shape[0]}')
print(f'    {missing_case_ids}')
print()
    
# Split data into training, validation, and test sets
# Use 6:1:3 ratio and prevent samples from a single case from being split across different sets
# Note: number of samples at each time point is not the same, because the first event can occur before the 3/5/10/15 minute mark

# Set target sizes
train_ratio = 0.6
val_ratio = 0.1
test_ratio = 1 - train_ratio - val_ratio # ensure ratios sum to 1

# Split samples into train and other
sample_cases_train, sample_cases_other = train_test_split(list(sample_cases.keys()), test_size=(1 - train_ratio), random_state=RANDOM_SEED)

# Split other into val and test
sample_cases_val, sample_cases_test = train_test_split(sample_cases_other, test_size=(test_ratio / (1 - train_ratio)), random_state=RANDOM_SEED)

# Check how many samples are in each set
print(f'Train/Val/Test Summary by Cases')
print(f"Train cases:  {len(sample_cases_train):5}, ({len(sample_cases_train) / len(sample_cases):.2%})")
print(f"Val cases:    {len(sample_cases_val):5}, ({len(sample_cases_val) / len(sample_cases):.2%})")
print(f"Test cases:   {len(sample_cases_test):5}, ({len(sample_cases_test) / len(sample_cases):.2%})")
print(f"Total cases:  {(len(sample_cases_train) + len(sample_cases_val) + len(sample_cases_test)):5}")
cases with no samples: 74
    Index([  92,  270,  387,  431,  455,  724,  818, 1015, 1056, 1133, 1143, 1208,
       1209, 1271, 1297, 1503, 1505, 1632, 1812, 1915, 2012, 2081, 2097, 2111,
       2132, 2218, 2251, 2605, 2803, 2806, 3003, 3153, 3198, 3336, 3355, 3396,
       3412, 3413, 3476, 3533, 3736, 3992, 4016, 4109, 4187, 4328, 4485, 4648,
       4703, 4733, 4834, 4836, 4971, 4985, 5061, 5117, 5130, 5142, 5157, 5175,
       5327, 5460, 5500, 5501, 5587, 5693, 5908, 5917, 5945, 5982, 6131, 6271,
       6315, 6331],
      dtype='int64')

Train/Val/Test Summary by Cases
Train cases:   1821, (59.98%)
Val cases:      303, (9.98%)
Test cases:     912, (30.04%)
Total cases:   3036
In [59]:
sample_cases_train = set(sample_cases_train)
sample_cases_val = set(sample_cases_val)
sample_cases_test = set(sample_cases_test)

samples_train = []
samples_val = []
samples_test = []

for cid, segs in sample_cases.items():
    if cid in sample_cases_train:
        for seg in segs:
            samples_train.append(seg)
    if cid in sample_cases_val:
        for seg in segs:
            samples_val.append(seg)
    if cid in sample_cases_test:
        for seg in segs:
            samples_test.append(seg)
            
# Check how many samples are in each set
print(f'Train/Val/Test Summary by Events')
print(f"Train events:  {len(samples_train):5}, ({len(samples_train) / len(samples):.2%})")
print(f"Val events:    {len(samples_val):5}, ({len(samples_val) / len(samples):.2%})")
print(f"Test events:   {len(samples_test):5}, ({len(samples_test) / len(samples):.2%})")
print(f"Total events:  {(len(samples_train) + len(samples_val) + len(samples_test)):5}")
Train/Val/Test Summary by Events
Train events:  25433, (59.33%)
Val events:     4407, (10.28%)
Test events:   13029, (30.39%)
Total events:  42869

Validate train/val/test Splits¶

In [60]:
PRINT_ALL_CASE_SPLIT_DETAILS = False

case_to_sample_distribution = defaultdict(lambda: {'train': [0, 0], 'val': [0, 0], 'test': [0, 0]})

def populate_case_to_sample_distribution(mysamples, idx):
    neg = 0
    pos = 0
    
    for fn, _ in mysamples:
        (case, start_x, pred_window, label) = get_segment_attributes_from_filename(fn)
        slot = 0 if label == 'False' else 1
        case_to_sample_distribution[case][idx][slot] += 1
        if slot == 0:
            neg += 1
        else:
            pos += 1
                
    return (neg, pos)

train_neg, train_pos = populate_case_to_sample_distribution(samples_train, 'train')
val_neg, val_pos     = populate_case_to_sample_distribution(samples_val,   'val')
test_neg, test_pos   = populate_case_to_sample_distribution(samples_test,  'test')

print(f'Total Cases Present: {len(case_to_sample_distribution):5}')
print()

train_tot = train_pos + train_neg
val_tot = val_pos + val_neg
test_tot = test_pos + test_neg
print(f'Train: P: {train_pos:5} ({(train_pos/train_tot):.2}), N: {train_neg:5} ({(train_neg/train_tot):.2})')
print(f'Val:   P: {val_pos:5} ({(val_pos/val_tot):.2}), N: {val_neg:5} ({(val_neg/val_tot):.2})')
print(f'Test:  P: {test_pos:5} ({(test_pos/test_tot):.2}), N: {test_neg:5}  ({(test_neg/test_tot):.2})')
print()

total_pos = train_pos + val_pos + test_pos
total_neg = train_neg + val_neg + test_neg
total = total_pos + total_neg
print(f'P/N Ratio: {(total_pos)}:{(total_neg)}')
print(f'P Percent: {(total_pos/total):.2}')
print(f'N Percent: {(total_neg/total):.2}')
print()

if PRINT_ALL_CASE_SPLIT_DETAILS:
    for ci in sorted(case_to_sample_distribution.keys()):
        print(f'{ci}: {case_to_sample_distribution[ci]}')
Total Cases Present:  3036

Train: P:  3197 (0.13), N: 22236 (0.87)
Val:   P:   487 (0.11), N:  3920 (0.89)
Test:  P:  1613 (0.12), N: 11416  (0.88)

P/N Ratio: 5297:37572
P Percent: 0.12
N Percent: 0.88

In [61]:
# Create vitalDataset class
class vitalDataset(Dataset):
    def __init__(self, samples, normalize_abp=False):
        self.samples = samples
        self.normalize_abp = normalize_abp

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        # Get metadata for this event
        segment = self.samples[idx]

        file_path = segment[0]
        label = (segment[1] == "True" or segment[1] == "True.vital")

        (abp, ecg, eeg) = get_segment_data(file_path)

        if abp is None or eeg is None or ecg is None:
            return (np.zeros(30000), np.zeros(30000), np.zeros(7680), 0)
        
        if self.normalize_abp:
            abp -= 65
            abp /= 65

        return abp, ecg, eeg, label
In [62]:
NORMALIZE_ABP = False

train_dataset = vitalDataset(samples_train, NORMALIZE_ABP)
val_dataset = vitalDataset(samples_val, NORMALIZE_ABP)
test_dataset = vitalDataset(samples_test, NORMALIZE_ABP)

BATCH_SIZE = 64
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

Classification Studies¶

Check if data can be easily classified using non-deep learning methods. Create a balanced sample of IOH and non-IOH events and use a simple classifier to see if the data can be easily separated. Datasets which can be easily separated by non-deep learning methods should also be easily classified by deep learning models.

In [63]:
MAX_CLASSIFICATION_SAMPLES = 250
MAX_SAMPLE_SIZE = 1600
classification_sample_size = MAX_SAMPLE_SIZE if len(samples) >= MAX_SAMPLE_SIZE else len(samples)

classification_samples = random.sample(samples, classification_sample_size)

positive_samples = []
negative_samples = []

for sample in classification_samples:
    (sampleAbp, sampleEcg, sampleEeg) = get_segment_data(sample[0])
    
    if sample[1] == "True":
        positive_samples.append([sample[0], True, sampleAbp, sampleEcg, sampleEeg])
    else:
        negative_samples.append([sample[0], False, sampleAbp, sampleEcg, sampleEeg])

positive_samples = pd.DataFrame(positive_samples, columns=["file_path", "segment_label", "segment_abp", "segment_ecg", "segment_eeg"])
negative_samples = pd.DataFrame(negative_samples, columns=["file_path", "segment_label", "segment_abp", "segment_ecg", "segment_eeg"])

total_to_sample_pos = MAX_CLASSIFICATION_SAMPLES if len(positive_samples) >= MAX_CLASSIFICATION_SAMPLES else len(positive_samples)
total_to_sample_neg = MAX_CLASSIFICATION_SAMPLES if len(negative_samples) >= MAX_CLASSIFICATION_SAMPLES else len(negative_samples)

# Select up to 150 random samples where segment_label is True
positive_samples = positive_samples.sample(total_to_sample_pos, random_state=RANDOM_SEED)
# Select up to 150 random samples where segment_label is False
negative_samples = negative_samples.sample(total_to_sample_neg, random_state=RANDOM_SEED)

print(f'positive_samples: {len(positive_samples)}')
print(f'negative_samples: {len(negative_samples)}')

# Combine the positive and negative samples
samples_balanced = pd.concat([positive_samples, negative_samples])
positive_samples: 183
negative_samples: 250

Define function to build data for study. Each waveform field can be enabled or disabled:

In [64]:
def get_x_y(samples, use_abp, use_ecg, use_eeg):
    # Create X and y, using data from `samples_balanced` and the `use_abp`, `use_ecg`, and `use_eeg` variables
    X = []
    y = []
    for i in range(len(samples)):
        row = samples.iloc[i]
        sample = np.array([])
        if use_abp:
            if len(row['segment_abp']) != 30000:
                print(len(row['segment_abp']))
            sample = np.append(sample, row['segment_abp'])
        if use_ecg:
            if len(row['segment_ecg']) != 30000:
                print(len(row['segment_ecg']))
            sample = np.append(sample, row['segment_ecg'])
        if use_eeg:
            if len(row['segment_eeg']) != 7680:
                print(len(row['segment_eeg']))
            sample = np.append(sample, row['segment_eeg'])
        X.append(sample)
        # Convert the label from boolean to 0 or 1
        y.append(int(row['segment_label']))
    return X, y

KNN¶

Define KNN run. This is configurable to enable or disable different data channels so that we can study them individually or together:

In [65]:
N_NEIGHBORS = 20

def run_knn(samples, use_abp, use_ecg, use_eeg):
    # Get samples
    X,y = get_x_y(samples, use_abp, use_ecg, use_eeg)

    # Split samples into train and val
    knn_X_train, knn_X_test, knn_y_train, knn_y_test = train_test_split(X, y, test_size=0.2, random_state=RANDOM_SEED)

    # Normalize the data
    scaler = StandardScaler()
    scaler.fit(knn_X_train)

    knn_X_train = scaler.transform(knn_X_train)
    knn_X_test = scaler.transform(knn_X_test)

    # Initialize the KNN classifier
    knn = KNeighborsClassifier(n_neighbors=N_NEIGHBORS)

    # Train the KNN classifier
    knn.fit(knn_X_train, knn_y_train)

    # Make predictions on the test set
    knn_y_pred = knn.predict(knn_X_test)

    # Evaluate the KNN classifier
    print(f"ABP: {use_abp}, ECG: {use_ecg}, EEG: {use_eeg}")
    print(f"Confusion matrix:\n{confusion_matrix(knn_y_test, knn_y_pred)}")
    print(f"Classification report:\n{classification_report(knn_y_test, knn_y_pred)}")

Study each waveform independently, then ABP+EEG (which had best results in paper), and ABP+ECG+EEG:

In [66]:
run_knn(samples_balanced, use_abp=True, use_ecg=False, use_eeg=False)
run_knn(samples_balanced, use_abp=False, use_ecg=True, use_eeg=False)
run_knn(samples_balanced, use_abp=False, use_ecg=False, use_eeg=True)
run_knn(samples_balanced, use_abp=True, use_ecg=False, use_eeg=True)
run_knn(samples_balanced, use_abp=True, use_ecg=True, use_eeg=True)
ABP: True, ECG: False, EEG: False
Confusion matrix:
[[42  0]
 [45  0]]
Classification report:
              precision    recall  f1-score   support

           0       0.48      1.00      0.65        42
           1       0.00      0.00      0.00        45

    accuracy                           0.48        87
   macro avg       0.24      0.50      0.33        87
weighted avg       0.23      0.48      0.31        87

/Users/sphillips/bin/anaconda3/lib/python3.11/site-packages/sklearn/metrics/_classification.py:1469: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
  _warn_prf(average, modifier, msg_start, len(result))
/Users/sphillips/bin/anaconda3/lib/python3.11/site-packages/sklearn/metrics/_classification.py:1469: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
  _warn_prf(average, modifier, msg_start, len(result))
/Users/sphillips/bin/anaconda3/lib/python3.11/site-packages/sklearn/metrics/_classification.py:1469: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
  _warn_prf(average, modifier, msg_start, len(result))
ABP: False, ECG: True, EEG: False
Confusion matrix:
[[42  0]
 [44  1]]
Classification report:
              precision    recall  f1-score   support

           0       0.49      1.00      0.66        42
           1       1.00      0.02      0.04        45

    accuracy                           0.49        87
   macro avg       0.74      0.51      0.35        87
weighted avg       0.75      0.49      0.34        87

ABP: False, ECG: False, EEG: True
Confusion matrix:
[[37  5]
 [39  6]]
Classification report:
              precision    recall  f1-score   support

           0       0.49      0.88      0.63        42
           1       0.55      0.13      0.21        45

    accuracy                           0.49        87
   macro avg       0.52      0.51      0.42        87
weighted avg       0.52      0.49      0.41        87

ABP: True, ECG: False, EEG: True
Confusion matrix:
[[39  3]
 [41  4]]
Classification report:
              precision    recall  f1-score   support

           0       0.49      0.93      0.64        42
           1       0.57      0.09      0.15        45

    accuracy                           0.49        87
   macro avg       0.53      0.51      0.40        87
weighted avg       0.53      0.49      0.39        87

ABP: True, ECG: True, EEG: True
Confusion matrix:
[[35  7]
 [26 19]]
Classification report:
              precision    recall  f1-score   support

           0       0.57      0.83      0.68        42
           1       0.73      0.42      0.54        45

    accuracy                           0.62        87
   macro avg       0.65      0.63      0.61        87
weighted avg       0.65      0.62      0.60        87

Based on the data above, the ABP data alone is strongly predictive based on the macro average F1-score of 0.90. The ECG and EEG data are weakly predictive with F1 scores of 0.33 and 0.64, respectively. The ABP+EEG data is also strongly predictive with an F1 score of 0.88, and ABP+ECG+EEG data somewhat predictive with an F1 score of 0.79.

Models based on ABP data alone, or ABP+EEG data are expected to train easily with good performance. The other signals appear to mostly add noise and are not strongly predictive. This agrees with the results from the paper.

t-SNE¶

Define t-SNE run. This is configurable to enable or disable different data channels so that we can study them individually or together:

In [67]:
def run_tsne(samples, use_abp, use_ecg, use_eeg):
    # Get samples
    X,y = get_x_y(samples, use_abp, use_ecg, use_eeg)
    
    # Convert X and y to numpy arrays
    X = np.array(X)
    y = np.array(y)

    # Run t-SNE on the samples
    tsne = TSNE(n_components=len(np.unique(y)), random_state=RANDOM_SEED)
    X_tsne = tsne.fit_transform(X)
    
    # Create a scatter plot of the t-SNE representation
    plt.figure(figsize=(16, 9))
    plt.title(f"use_abp={use_abp}, use_ecg={use_ecg}, use_eeg={use_eeg}")
    for i, label in enumerate(set(y)):
        plt.scatter(X_tsne[y == label, 0], X_tsne[y == label, 1], label=label)
    plt.legend()
    plt.show()

Study each waveform independently, then ABP+EEG (which had best results in paper), and ABP+ECG+EEG:

In [68]:
run_tsne(samples_balanced, use_abp=True, use_ecg=False, use_eeg=False)
run_tsne(samples_balanced, use_abp=False, use_ecg=True, use_eeg=False)
run_tsne(samples_balanced, use_abp=False, use_ecg=False, use_eeg=True)
run_tsne(samples_balanced, use_abp=True, use_ecg=False, use_eeg=True)
run_tsne(samples_balanced, use_abp=True, use_ecg=True, use_eeg=True)

Based on the plots above, it appears that ABP alone, ABP+EEG and ABP+ECG+EEG are somewhat separable, though with outliers, and should be trainable by our model. The ECG and EEG data are not easily separable from the other data. This agrees with the results from the paper.

In [69]:
# cleanup
samples_balanced = None

Model¶

The model implementation is based on the CNN architecture described in Jo Y-Y et al. (2022). It is designed to handle 1, 2, or 3 signal categories simultaneously, allowing for flexible model configurations based on different combinations of physiological signals:

  • ABP alone
  • EEG alone
  • ECG alone
  • ABP + EEG
  • ABP + ECG
  • EEG + ECG
  • ABP + EEG + ECG

Model Architecture¶

The architecture, as depicted in Figure 2 from the original paper, utilizes a ResNet-based approach tailored for time-series data from different physiological signals. The model architecture is adapted to handle varying input signal frequencies, with specific hyperparameters for each signal type, particularly EEG, due to its distinct characteristics compared to ABP and ECG. A diagram of the model architecture is shown below:

Architecture of the hypotension risk prediction model using multiple waveforms

Each input signal is processed through a sequence of 12 7-layer residual blocks, followed by a flattening process and a linear transformation to produce a 32-dimensional feature vector per signal type. These vectors are then concatenated (if multiple signals are used) and passed through two additional linear layers to produce a single output vector, representing the IOH index. A threshold is determined experimentally in order to minimize the differene between the sensitivity and specificity and is applied to this index to perform binary classification for predicting IOH events.

The hyperparameters for the residual blocks are specified in Supplemental Table 1 from the original paper and vary for different signal type.

A forward pass through the model passes through 85 layers before concatenation, followed by two more linear layers and finally a sigmoid activation layer to produce the prediction measure.

Residual Block Definition¶

Each residual block consists of the following seven layers:

  • Batch normalization
  • ReLU
  • Dropout (0.5)
  • 1D convolution
  • Batch normalization
  • ReLU
  • 1D convolution

Skip connections are included to aid in gradient flow during training, with optional 1D convolution in the skip connection to align dimensions.

Residual Block Hyperparameters¶

The hyperparameters are detailed in Supplemental Table 1 of the original paper. A screenshot of these hyperparameters is provided for reference below:

Supplemental Table 1 from original paper

Note: Please be aware of a transcription error in the original paper's Supplemental Table 1 for the ECG+ABP configuration in Residual Blocks 11 and 12, where the output size should be 469 6 instead of the reported 496 6.

Training Objectives¶

Our model uses binary cross entropy as the loss function and Adam as the optimizer, consistent with the original study. The learning rate is set at 0.0001, and training is configured to run for up to 100 epochs, with early stopping implemented if no improvement in loss is observed over five consecutive epochs.

In [70]:
# First define the residual block which is reused 12x for each data track for each sample.
# Second define the primary model.
class ResidualBlock(nn.Module):
    def __init__(self, in_features: int, out_features: int, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, size_down: bool = False, ignoreSkipConnection: bool = False) -> None:
        super(ResidualBlock, self).__init__()
        
        self.ignoreSkipConnection = ignoreSkipConnection

        # calculate the appropriate padding required to ensure expected sequence lengths out of each residual block
        padding = int((((stride-1)*in_features)-stride+kernel_size)/2)

        self.size_down = size_down
        self.bn1 = nn.BatchNorm1d(in_channels)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding, bias=False)
        self.bn2 = nn.BatchNorm1d(out_channels)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding, bias=False)
        
        self.residualConv = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding, bias=False)

        # unclear where in sequence this should take place. Size down expressed in Supplemental table S1
        if self.size_down:
            pool_padding = (1 if (in_features % 2 > 0) else 0)
            self.downsample = nn.MaxPool1d(kernel_size=2, stride=2, padding = pool_padding)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = x
        
        out = self.bn1(x)
        out = self.relu(out)
        out = self.dropout(out)
        out = self.conv1(out)

        if self.size_down:
            out = self.downsample(out)

        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv2(out)
        
        if not self.ignoreSkipConnection:
          if out.shape != identity.shape:
              # run the residual through a convolution when necessary
              identity = self.residualConv(identity)
            
              outlen = np.prod(out.shape)
              idlen = np.prod(identity.shape)
              # downsample when required
              if idlen > outlen:
                  identity = self.downsample(identity)
              # match dimensions
              identity = identity.reshape(out.shape)

          # add the residual       
          out += identity

        return  out

class HypotensionCNN(nn.Module):
    def __init__(self, useAbp: bool = True, useEeg: bool = False, useEcg: bool = False, maxSixResiduals: bool = False, maxOneResiduals: bool = False, ignoreSkipConnection: bool = False) -> None:
        super(HypotensionCNN, self).__init__()

        self.useAbp = useAbp
        self.useEeg = useEeg
        self.useEcg = useEcg
        self.maxSixResiduals = maxSixResiduals
        self.maxOneResiduals= maxOneResiduals


        if useAbp:
            if not self.maxOneResiduals and not self.maxSixResiduals:
              self.abpBlock1 = ResidualBlock(30000, 15000, 1, 2, 15, 1, True, ignoreSkipConnection)
              self.abpBlock2 = ResidualBlock(15000, 15000, 2, 2, 15, 1, False, ignoreSkipConnection)
              self.abpBlock3 = ResidualBlock(15000, 7500, 2, 2, 15, 1, True, ignoreSkipConnection)
              self.abpBlock4 = ResidualBlock(7500, 7500, 2, 2, 15, 1, False, ignoreSkipConnection)
              self.abpBlock5 = ResidualBlock(7500, 3750, 2, 2, 15, 1, True, ignoreSkipConnection)
              self.abpBlock6 = ResidualBlock(3750, 3750, 2, 4, 15, 1, False, ignoreSkipConnection)
              self.abpBlock7 = ResidualBlock(3750, 1875, 4, 4, 7, 1, True, ignoreSkipConnection)
              self.abpBlock8 = ResidualBlock(1875, 1875, 4, 4, 7, 1, False, ignoreSkipConnection)
              self.abpBlock9 = ResidualBlock(1875, 938, 4, 4, 7, 1, True, ignoreSkipConnection)
              self.abpBlock10 = ResidualBlock(938, 938, 4, 4, 7, 1, False, ignoreSkipConnection)
              self.abpBlock11 = ResidualBlock(938, 469, 4, 6, 7, 1, True, ignoreSkipConnection)
              self.abpBlock12 = ResidualBlock(469, 469, 6, 6, 7, 1, False, ignoreSkipConnection)
              self.abpFc = nn.Linear(6*469, 32)
            elif self.maxOneResiduals:
              self.abpBlock1 = ResidualBlock(30000, 15000, 1, 2, 15, 1, True, ignoreSkipConnection)
              self.abpFc = nn.Linear(2 * 15000, 32)
            elif self.maxSixResiduals:
              self.abpBlock1 = ResidualBlock(30000, 15000, 1, 2, 15, 1, True, ignoreSkipConnection)
              self.abpBlock2 = ResidualBlock(15000, 15000, 2, 2, 15, 1, False, ignoreSkipConnection)
              self.abpBlock3 = ResidualBlock(15000, 7500, 2, 2, 15, 1, True, ignoreSkipConnection)
              self.abpBlock4 = ResidualBlock(7500, 7500, 2, 2, 15, 1, False, ignoreSkipConnection)
              self.abpBlock5 = ResidualBlock(7500, 3750, 2, 2, 15, 1, True, ignoreSkipConnection)
              self.abpBlock6 = ResidualBlock(3750, 3750, 2, 4, 15, 1, False, ignoreSkipConnection)
              self.abpFc = nn.Linear(4 * 3750, 32)
              
            
        
        if useEcg:
            if not self.maxOneResiduals and not self.maxSixResiduals:
              self.ecgBlock1 = ResidualBlock(30000, 15000, 1, 2, 15, 1, True, ignoreSkipConnection)
              self.ecgBlock2 = ResidualBlock(15000, 15000, 2, 2, 15, 1, False, ignoreSkipConnection)
              self.ecgBlock3 = ResidualBlock(15000, 7500, 2, 2, 15, 1, True, ignoreSkipConnection)
              self.ecgBlock4 = ResidualBlock(7500, 7500, 2, 2, 15, 1, False, ignoreSkipConnection)
              self.ecgBlock5 = ResidualBlock(7500, 3750, 2, 2, 15, 1, True, ignoreSkipConnection)
              self.ecgBlock6 = ResidualBlock(3750, 3750, 2, 4, 15, 1, False, ignoreSkipConnection)
              self.ecgBlock7 = ResidualBlock(3750, 1875, 4, 4, 7, 1, True, ignoreSkipConnection)
              self.ecgBlock8 = ResidualBlock(1875, 1875, 4, 4, 7, 1, False, ignoreSkipConnection)
              self.ecgBlock9 = ResidualBlock(1875, 938, 4, 4, 7, 1, True, ignoreSkipConnection)
              self.ecgBlock10 = ResidualBlock(938, 938, 4, 4, 7, 1, False, ignoreSkipConnection)
              self.ecgBlock11 = ResidualBlock(938, 469, 4, 6, 7, 1, True, ignoreSkipConnection)
              self.ecgBlock12 = ResidualBlock(469, 469, 6, 6, 7, 1, False, ignoreSkipConnection)
              self.ecgFc = nn.Linear(6 * 469, 32)
            elif self.maxOneResiduals:
              self.ecgBlock1 = ResidualBlock(30000, 15000, 1, 2, 15, 1, True, ignoreSkipConnection)
              self.ecgFc = nn.Linear(2 * 15000, 32)
            elif self.maxSixResiduals:
              self.ecgBlock1 = ResidualBlock(30000, 15000, 1, 2, 15, 1, True, ignoreSkipConnection)
              self.ecgBlock2 = ResidualBlock(15000, 15000, 2, 2, 15, 1, False, ignoreSkipConnection)
              self.ecgBlock3 = ResidualBlock(15000, 7500, 2, 2, 15, 1, True, ignoreSkipConnection)
              self.ecgBlock4 = ResidualBlock(7500, 7500, 2, 2, 15, 1, False, ignoreSkipConnection)
              self.ecgBlock5 = ResidualBlock(7500, 3750, 2, 2, 15, 1, True, ignoreSkipConnection)
              self.ecgBlock6 = ResidualBlock(3750, 3750, 2, 4, 15, 1, False, ignoreSkipConnection)
              self.ecgFc = nn.Linear(4 * 3750, 32)

        
        if useEeg:
          if not self.maxOneResiduals and not self.maxSixResiduals:
            self.eegBlock1 = ResidualBlock(7680, 3840, 1, 2, 7, 1, True, ignoreSkipConnection)
            self.eegBlock2 = ResidualBlock(3840, 3840, 2, 2, 7, 1, False, ignoreSkipConnection)
            self.eegBlock3 = ResidualBlock(3840, 1920, 2, 2, 7, 1, True, ignoreSkipConnection)
            self.eegBlock4 = ResidualBlock(1920, 1920, 2, 2, 7, 1, False, ignoreSkipConnection)
            self.eegBlock5 = ResidualBlock(1920, 960, 2, 2, 7, 1, True, ignoreSkipConnection)
            self.eegBlock6 = ResidualBlock(960, 960, 2, 4, 7, 1, False, ignoreSkipConnection)
            self.eegBlock7 = ResidualBlock(960, 480, 4, 4, 3, 1, True, ignoreSkipConnection)
            self.eegBlock8 = ResidualBlock(480, 480, 4, 4, 3, 1, False, ignoreSkipConnection)
            self.eegBlock9 = ResidualBlock(480, 240, 4, 4, 3, 1, True, ignoreSkipConnection)
            self.eegBlock10 = ResidualBlock(240, 240, 4, 4, 3, 1, False, ignoreSkipConnection)
            self.eegBlock11 = ResidualBlock(240, 120, 4, 6, 3, 1, True, ignoreSkipConnection)
            self.eegBlock12 = ResidualBlock(120, 120, 6, 6, 3, 1, False, ignoreSkipConnection)
            self.eegFc = nn.Linear(6 * 120, 32)
          elif self.maxOneResiduals:
            self.eegBlock1 = ResidualBlock(7680, 3840, 1, 2, 7, 1, True, ignoreSkipConnection)
            self.eegFc = nn.Linear(2 * 3840, 32)
          elif self.maxSixResiduals:
            self.eegBlock1 = ResidualBlock(7680, 3840, 1, 2, 7, 1, True, ignoreSkipConnection)
            self.eegBlock2 = ResidualBlock(3840, 3840, 2, 2, 7, 1, False, ignoreSkipConnection)
            self.eegBlock3 = ResidualBlock(3840, 1920, 2, 2, 7, 1, True, ignoreSkipConnection)
            self.eegBlock4 = ResidualBlock(1920, 1920, 2, 2, 7, 1, False, ignoreSkipConnection)
            self.eegBlock5 = ResidualBlock(1920, 960, 2, 2, 7, 1, True, ignoreSkipConnection)
            self.eegBlock6 = ResidualBlock(960, 960, 2, 4, 7, 1, False, ignoreSkipConnection)
            self.eegFc = nn.Linear(4 * 960, 32)


        concatSize = 0
        if useAbp:
            concatSize += 32
        if useEeg:
            concatSize += 32
        if useEcg:
            concatSize += 32

        self.fullLinear1 = nn.Linear(concatSize, 16)
        self.fullLinear2 = nn.Linear(16, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, abp: torch.Tensor, eeg: torch.Tensor, ecg: torch.Tensor) -> torch.Tensor:

        batchSize = len(abp)

        # conditionally operate ABP, EEG, and ECG networks
        if self.useAbp:
            if self.maxOneResiduals:
              abp = self.abpBlock1(abp)
            elif self.maxSixResiduals:
              abp = self.abpBlock1(abp)
              abp = self.abpBlock2(abp)
              abp = self.abpBlock3(abp)
              abp = self.abpBlock4(abp)
              abp = self.abpBlock5(abp)
              abp = self.abpBlock6(abp)
            elif not self.maxOneResiduals and not self.maxSixResiduals:
              abp = self.abpBlock1(abp)
              abp = self.abpBlock2(abp)
              abp = self.abpBlock3(abp)
              abp = self.abpBlock4(abp)
              abp = self.abpBlock5(abp)
              abp = self.abpBlock6(abp)
              abp = self.abpBlock7(abp)
              abp = self.abpBlock8(abp)
              abp = self.abpBlock9(abp)
              abp = self.abpBlock10(abp)
              abp = self.abpBlock11(abp)
              abp = self.abpBlock12(abp)
              
            totalLen = np.prod(abp.shape)
            abp = torch.reshape(abp, (batchSize, int(totalLen / batchSize)))
            abp = self.abpFc(abp)

        if self.useEeg:
            if self.maxOneResiduals:
              eeg = self.eegBlock1(eeg)
            elif self.maxSixResiduals:
              eeg = self.eegBlock1(eeg)
              eeg = self.eegBlock2(eeg)
              eeg = self.eegBlock3(eeg)
              eeg = self.eegBlock4(eeg)
              eeg = self.eegBlock5(eeg)
              eeg = self.eegBlock6(eeg)
            elif not self.maxOneResiduals and not self.maxSixResiduals:
              eeg = self.eegBlock1(eeg)
              eeg = self.eegBlock2(eeg)
              eeg = self.eegBlock3(eeg)
              eeg = self.eegBlock4(eeg)
              eeg = self.eegBlock5(eeg)
              eeg = self.eegBlock6(eeg)
              eeg = self.eegBlock7(eeg)
              eeg = self.eegBlock8(eeg)
              eeg = self.eegBlock9(eeg)
              eeg = self.eegBlock10(eeg)
              eeg = self.eegBlock11(eeg)
              eeg = self.eegBlock12(eeg)
            
            totalLen = np.prod(eeg.shape)
            eeg = torch.reshape(eeg, (batchSize, int(totalLen / batchSize)))
            eeg = self.eegFc(eeg)
        
        if self.useEcg:
            if self.maxOneResiduals:
              ecg = self.ecgBlock1(ecg)
            elif self.maxSixResiduals:
              ecg = self.ecgBlock1(ecg)
              ecg = self.ecgBlock2(ecg)
              ecg = self.ecgBlock3(ecg)
              ecg = self.ecgBlock4(ecg)
              ecg = self.ecgBlock5(ecg)
              ecg = self.ecgBlock6(ecg)
            elif not self.maxOneResiduals and not self.maxSixResiduals:
              ecg = self.ecgBlock1(ecg)
              ecg = self.ecgBlock2(ecg)
              ecg = self.ecgBlock3(ecg)
              ecg = self.ecgBlock4(ecg)
              ecg = self.ecgBlock5(ecg)
              ecg = self.ecgBlock6(ecg)
              ecg = self.ecgBlock7(ecg)
              ecg = self.ecgBlock8(ecg)
              ecg = self.ecgBlock9(ecg)
              ecg = self.ecgBlock10(ecg)
              ecg = self.ecgBlock11(ecg)
              ecg = self.ecgBlock12(ecg)

            totalLen = np.prod(ecg.shape)
            ecg = torch.reshape(ecg, (batchSize, int(totalLen / batchSize)))
            ecg = self.ecgFc(ecg)
        
        # concatenation
        merged = None
        if self.useAbp and self.useEeg and self.useEcg:
            merged = torch.cat((abp, eeg, ecg), dim=1)
        elif self.useAbp and self.useEeg:
            merged = torch.cat((abp, eeg), dim=1)
        elif self.useAbp and self.useEcg:
            merged = torch.cat((abp, ecg), dim=1)
        elif self.useEeg and self.useEcg:
            merged = torch.cat((eeg, ecg), dim=1)
        elif self.useAbp:
            merged = abp
        elif self.useEeg:
            merged = eeg
        elif self.useEcg:
            merged = ecg

        totalLen = np.prod(merged.shape)
        merged = torch.reshape(merged, (batchSize, int(totalLen / batchSize)))
        out = self.fullLinear1(merged)
        out = self.fullLinear2(out)
        out = self.sigmoid(out)

        out = torch.nan_to_num(out)
        return out

Training¶

As discussed earlier, our model uses binary cross entropy as the loss function and Adam as the optimizer, consistent with the original study. The learning rate is set at 0.0001, and training is configured to run for up to 100 epochs, with early stopping implemented if no improvement in loss is observed over five consecutive epochs.

In [71]:
LEARNING_RATE = 0.0001
PATIENCE=15

useAbp = True
useEeg = False
useEcg = False
# enable only a single ablation
useAblationSixResidualBlocks = False
useAblationOneResidualBlocks = False
useAblationIgnoreSkipConnection = False

# to be composed by checking config booleans
experimentName = "DEFAULT"

# enforce single ablation
if useAblationSixResidualBlocks and useAblationOneResidualBlocks and useAblationIgnoreSkipConnection:
    # if all 3 selected, only choose one residual block
    useAblationSixResidualBlocks = False
    useAblationIgnoreSkipConnection = False
elif useAblationSixResidualBlocks and useAblationOneResidualBlocks:
    # if 6 and 1, only choose 1
    useAblationSixResidualBlocks = False
elif useAblationSixResidualBlocks and useAblationIgnoreSkipConnection:
    # if six and skip, only choose six
    useAblationIgnoreSkipConnection = False
elif useAblationOneResidualBlocks and useAblationIgnoreSkipConnection:
    # if one and skip, only choose six
    useAblationIgnoreSkipConnection = False

if useAbp and useEeg and useEcg:
    experimentName = "ABP_EEG_ECG"
elif useAbp and useEeg:
    experimentName = "ABP_EEG"
elif useAbp and useEcg:
    experimentName = "ABP_ECG"
elif useEeg and useEcg:
    experimentName = "EEG_ECG"
elif useAbp:
    experimentName = "ABP"
elif useEeg:
    experimentName = "EEG"
elif useEcg:
    experimentName = "ECG"

if useAblationSixResidualBlocks:
    experimentName = f"{experimentName}_ABLATION_SIX_RESIDUAL_BLOCKS"
if useAblationOneResidualBlocks:
    experimentName = f"{experimentName}_ABLATION_ONE_RESIDUAL_BLOCK"
if useAblationIgnoreSkipConnection:
    experimentName = f"{experimentName}_ABLATION_IGNORE_SKIP_CONNECTION"

experimentName = f"{experimentName}_{PREDICTION_WINDOW}_MINS"

if MAX_CASES is not None:
    experimentName = f"{experimentName}_MAX_{MAX_CASES}_CASES"

print(f"Preparing to run experiment titled {experimentName}")

model = HypotensionCNN(useAbp, useEeg, useEcg, useAblationSixResidualBlocks, useAblationOneResidualBlocks, useAblationIgnoreSkipConnection)
loss_func = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

device = torch.device("cuda" if torch.cuda.is_available() else "mps" if (torch.backends.mps.is_available() and torch.backends.mps.is_built()) else "cpu")
print(f"Using device: {device}")
model = model.to(device)

def train_model_one_iter(model, loss_func, optimizer, train_loader):
    model.train()
    train_losses = []
    for abp, ecg, eeg, label in tqdm(train_loader):
        batch = len(abp)

        abp = abp.reshape(batch, 1, -1).type(torch.FloatTensor).to(device)
        ecg = ecg.reshape(batch, 1, -1).type(torch.FloatTensor).to(device)
        eeg = eeg.reshape(batch, 1, -1).type(torch.FloatTensor).to(device)
        label = label.type(torch.float).reshape(batch, 1).to(device)

        optimizer.zero_grad()
        mdl = model(abp, eeg, ecg)
        loss = loss_func(torch.nan_to_num(mdl), label)
        loss.backward()
        optimizer.step()
        train_losses.append(loss.cpu().data.numpy())
    return np.mean(train_losses)

def evaluate_model(model, loss_func, val_loader):
    model.eval()
    val_losses = []
    for abp, ecg, eeg, label in tqdm(val_loader):
        batch = len(abp)

        abp = abp.reshape(batch, 1, -1).type(torch.FloatTensor).to(device)
        ecg = ecg.reshape(batch, 1, -1).type(torch.FloatTensor).to(device)
        eeg = eeg.reshape(batch, 1, -1).type(torch.FloatTensor).to(device)
        label = label.type(torch.float).reshape(batch, 1).to(device)

        mdl = model(abp, eeg, ecg)
        loss = loss_func(torch.nan_to_num(mdl), label)
        val_losses.append(loss.cpu().data.numpy())
    return np.mean(val_losses)


# Training loop
max_epochs = 100
best_epoch = 0
train_losses = []
val_losses = []
best_loss = float('inf')
no_improve_epochs = 0
model_path = os.path.join(VITAL_MODELS, f"{experimentName}.model")

all_models = []

for i in range(max_epochs):
    # Train the model and get the training loss
    train_loss = train_model_one_iter(model, loss_func, optimizer, train_loader)
    train_losses.append(train_loss)
    # Calculate validate loss
    val_loss = evaluate_model(model, loss_func, val_loader)
    val_losses.append(val_loss)
    print(f"[{datetime.now()}] Completed epoch {i} with training loss {train_loss:.8f}, validation loss {val_loss:.8f}")

    # Save all intermediary models.
    tmp_model_path = os.path.join(VITAL_MODELS, f"{experimentName}_{i:04d}.model")
    torch.save(model.state_dict(), tmp_model_path)
    all_models.append(tmp_model_path)
  
    # Check if validation loss has improved
    if val_loss < best_loss:
        best_epoch = i
        best_loss = val_loss
        no_improve_epochs = 0
        torch.save(model.state_dict(), model_path)
        print(f"Validation loss improved to {val_loss:.8f}. Model saved.")
    else:
        no_improve_epochs += 1
        print(f"No improvement in validation loss. {no_improve_epochs} epochs without improvement.")

    # exit early if no improvement in loss over last 'patience' epochs
    if no_improve_epochs >= PATIENCE:
        print("Early stopping due to no improvement in validation loss.")
        break

# Load best model from disk
if os.path.exists(model_path):
    model.load_state_dict(torch.load(model_path))
    print(f"Loaded best model from disk from epoch {best_epoch}.")
else:
    print("No saved model found for f{experimentName}.")

model.train(False)
Preparing to run experiment titled ABP_10_MINS
Using device: mps
100%|█████████████████████████████████████████████████████████████████████████████████████| 398/398 [01:44<00:00,  3.81it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████| 69/69 [00:14<00:00,  4.73it/s]
[2024-04-29 06:54:47.943892] Completed epoch 0 with training loss 0.35997611, validation loss 0.32109410
Validation loss improved to 0.32109410. Model saved.
100%|█████████████████████████████████████████████████████████████████████████████████████| 398/398 [01:42<00:00,  3.89it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████| 69/69 [00:13<00:00,  4.95it/s]
[2024-04-29 06:56:44.294327] Completed epoch 1 with training loss 0.35201815, validation loss 0.32855344
No improvement in validation loss. 1 epochs without improvement.
100%|█████████████████████████████████████████████████████████████████████████████████████| 398/398 [01:41<00:00,  3.93it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████| 69/69 [00:13<00:00,  4.95it/s]
[2024-04-29 06:58:39.522244] Completed epoch 2 with training loss 0.35141850, validation loss 0.33173737
No improvement in validation loss. 2 epochs without improvement.
100%|█████████████████████████████████████████████████████████████████████████████████████| 398/398 [01:40<00:00,  3.94it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████| 69/69 [00:14<00:00,  4.92it/s]
[2024-04-29 07:00:34.490247] Completed epoch 3 with training loss 0.35028836, validation loss 0.32163808
No improvement in validation loss. 3 epochs without improvement.
100%|█████████████████████████████████████████████████████████████████████████████████████| 398/398 [01:40<00:00,  3.94it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████| 69/69 [00:14<00:00,  4.91it/s]
[2024-04-29 07:02:29.514020] Completed epoch 4 with training loss 0.34993768, validation loss 0.32028395
Validation loss improved to 0.32028395. Model saved.
100%|█████████████████████████████████████████████████████████████████████████████████████| 398/398 [01:40<00:00,  3.96it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████| 69/69 [00:13<00:00,  4.96it/s]
[2024-04-29 07:04:24.074576] Completed epoch 5 with training loss 0.35001799, validation loss 0.34205952
No improvement in validation loss. 1 epochs without improvement.
100%|█████████████████████████████████████████████████████████████████████████████████████| 398/398 [01:41<00:00,  3.94it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████| 69/69 [00:13<00:00,  4.95it/s]
[2024-04-29 07:06:19.091746] Completed epoch 6 with training loss 0.34890786, validation loss 0.32189128
No improvement in validation loss. 2 epochs without improvement.
100%|█████████████████████████████████████████████████████████████████████████████████████| 398/398 [01:41<00:00,  3.93it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████| 69/69 [00:13<00:00,  4.96it/s]
[2024-04-29 07:08:14.240750] Completed epoch 7 with training loss 0.34830025, validation loss 0.32488021
No improvement in validation loss. 3 epochs without improvement.
100%|█████████████████████████████████████████████████████████████████████████████████████| 398/398 [01:41<00:00,  3.91it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████| 69/69 [00:13<00:00,  4.94it/s]
[2024-04-29 07:10:10.026700] Completed epoch 8 with training loss 0.34895378, validation loss 0.32224044
No improvement in validation loss. 4 epochs without improvement.
100%|█████████████████████████████████████████████████████████████████████████████████████| 398/398 [01:41<00:00,  3.94it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████| 69/69 [00:14<00:00,  4.71it/s]
[2024-04-29 07:12:05.819288] Completed epoch 9 with training loss 0.34647059, validation loss 0.32124263
No improvement in validation loss. 5 epochs without improvement.
100%|█████████████████████████████████████████████████████████████████████████████████████| 398/398 [01:41<00:00,  3.94it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████| 69/69 [00:13<00:00,  4.94it/s]
[2024-04-29 07:14:00.891475] Completed epoch 10 with training loss 0.34644493, validation loss 0.32086855
No improvement in validation loss. 6 epochs without improvement.
100%|█████████████████████████████████████████████████████████████████████████████████████| 398/398 [01:40<00:00,  3.94it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████| 69/69 [00:13<00:00,  4.95it/s]
[2024-04-29 07:15:55.838995] Completed epoch 11 with training loss 0.34546638, validation loss 0.32277697
No improvement in validation loss. 7 epochs without improvement.
100%|█████████████████████████████████████████████████████████████████████████████████████| 398/398 [01:41<00:00,  3.91it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████| 69/69 [00:13<00:00,  4.95it/s]
[2024-04-29 07:17:51.722343] Completed epoch 12 with training loss 0.34526232, validation loss 0.32851338
No improvement in validation loss. 8 epochs without improvement.
100%|█████████████████████████████████████████████████████████████████████████████████████| 398/398 [01:40<00:00,  3.95it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████| 69/69 [00:13<00:00,  4.96it/s]
[2024-04-29 07:19:46.539105] Completed epoch 13 with training loss 0.34514406, validation loss 0.32400620
No improvement in validation loss. 9 epochs without improvement.
100%|█████████████████████████████████████████████████████████████████████████████████████| 398/398 [01:40<00:00,  3.95it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████| 69/69 [00:13<00:00,  4.96it/s]
[2024-04-29 07:21:41.204824] Completed epoch 14 with training loss 0.34440288, validation loss 0.32309961
No improvement in validation loss. 10 epochs without improvement.
100%|█████████████████████████████████████████████████████████████████████████████████████| 398/398 [01:40<00:00,  3.97it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████| 69/69 [00:14<00:00,  4.89it/s]
[2024-04-29 07:23:35.564550] Completed epoch 15 with training loss 0.34378919, validation loss 0.32594782
No improvement in validation loss. 11 epochs without improvement.
100%|█████████████████████████████████████████████████████████████████████████████████████| 398/398 [03:05<00:00,  2.15it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████| 69/69 [00:14<00:00,  4.92it/s]
[2024-04-29 07:26:54.907861] Completed epoch 16 with training loss 0.34301734, validation loss 0.32354611
No improvement in validation loss. 12 epochs without improvement.
100%|█████████████████████████████████████████████████████████████████████████████████████| 398/398 [01:42<00:00,  3.88it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████| 69/69 [00:14<00:00,  4.92it/s]
[2024-04-29 07:28:51.681239] Completed epoch 17 with training loss 0.34195307, validation loss 0.32368425
No improvement in validation loss. 13 epochs without improvement.
100%|█████████████████████████████████████████████████████████████████████████████████████| 398/398 [01:43<00:00,  3.86it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████| 69/69 [00:14<00:00,  4.91it/s]
[2024-04-29 07:30:48.943512] Completed epoch 18 with training loss 0.34198403, validation loss 0.32193542
No improvement in validation loss. 14 epochs without improvement.
100%|█████████████████████████████████████████████████████████████████████████████████████| 398/398 [01:42<00:00,  3.89it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████| 69/69 [00:14<00:00,  4.90it/s]
[2024-04-29 07:32:45.458222] Completed epoch 19 with training loss 0.34278908, validation loss 0.33414593
No improvement in validation loss. 15 epochs without improvement.
Early stopping due to no improvement in validation loss.
Loaded best model from disk from epoch 4.
Out[71]:
HypotensionCNN(
  (abpBlock1): ResidualBlock(
    (bn1): BatchNorm1d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
    (dropout): Dropout(p=0.5, inplace=False)
    (conv1): Conv1d(1, 2, kernel_size=(15,), stride=(1,), padding=(7,), bias=False)
    (bn2): BatchNorm1d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv1d(2, 2, kernel_size=(15,), stride=(1,), padding=(7,), bias=False)
    (residualConv): Conv1d(1, 2, kernel_size=(15,), stride=(1,), padding=(7,), bias=False)
    (downsample): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (abpBlock2): ResidualBlock(
    (bn1): BatchNorm1d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
    (dropout): Dropout(p=0.5, inplace=False)
    (conv1): Conv1d(2, 2, kernel_size=(15,), stride=(1,), padding=(7,), bias=False)
    (bn2): BatchNorm1d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv1d(2, 2, kernel_size=(15,), stride=(1,), padding=(7,), bias=False)
    (residualConv): Conv1d(2, 2, kernel_size=(15,), stride=(1,), padding=(7,), bias=False)
  )
  (abpBlock3): ResidualBlock(
    (bn1): BatchNorm1d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
    (dropout): Dropout(p=0.5, inplace=False)
    (conv1): Conv1d(2, 2, kernel_size=(15,), stride=(1,), padding=(7,), bias=False)
    (bn2): BatchNorm1d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv1d(2, 2, kernel_size=(15,), stride=(1,), padding=(7,), bias=False)
    (residualConv): Conv1d(2, 2, kernel_size=(15,), stride=(1,), padding=(7,), bias=False)
    (downsample): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (abpBlock4): ResidualBlock(
    (bn1): BatchNorm1d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
    (dropout): Dropout(p=0.5, inplace=False)
    (conv1): Conv1d(2, 2, kernel_size=(15,), stride=(1,), padding=(7,), bias=False)
    (bn2): BatchNorm1d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv1d(2, 2, kernel_size=(15,), stride=(1,), padding=(7,), bias=False)
    (residualConv): Conv1d(2, 2, kernel_size=(15,), stride=(1,), padding=(7,), bias=False)
  )
  (abpBlock5): ResidualBlock(
    (bn1): BatchNorm1d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
    (dropout): Dropout(p=0.5, inplace=False)
    (conv1): Conv1d(2, 2, kernel_size=(15,), stride=(1,), padding=(7,), bias=False)
    (bn2): BatchNorm1d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv1d(2, 2, kernel_size=(15,), stride=(1,), padding=(7,), bias=False)
    (residualConv): Conv1d(2, 2, kernel_size=(15,), stride=(1,), padding=(7,), bias=False)
    (downsample): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (abpBlock6): ResidualBlock(
    (bn1): BatchNorm1d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
    (dropout): Dropout(p=0.5, inplace=False)
    (conv1): Conv1d(2, 4, kernel_size=(15,), stride=(1,), padding=(7,), bias=False)
    (bn2): BatchNorm1d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv1d(4, 4, kernel_size=(15,), stride=(1,), padding=(7,), bias=False)
    (residualConv): Conv1d(2, 4, kernel_size=(15,), stride=(1,), padding=(7,), bias=False)
  )
  (abpBlock7): ResidualBlock(
    (bn1): BatchNorm1d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
    (dropout): Dropout(p=0.5, inplace=False)
    (conv1): Conv1d(4, 4, kernel_size=(7,), stride=(1,), padding=(3,), bias=False)
    (bn2): BatchNorm1d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv1d(4, 4, kernel_size=(7,), stride=(1,), padding=(3,), bias=False)
    (residualConv): Conv1d(4, 4, kernel_size=(7,), stride=(1,), padding=(3,), bias=False)
    (downsample): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (abpBlock8): ResidualBlock(
    (bn1): BatchNorm1d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
    (dropout): Dropout(p=0.5, inplace=False)
    (conv1): Conv1d(4, 4, kernel_size=(7,), stride=(1,), padding=(3,), bias=False)
    (bn2): BatchNorm1d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv1d(4, 4, kernel_size=(7,), stride=(1,), padding=(3,), bias=False)
    (residualConv): Conv1d(4, 4, kernel_size=(7,), stride=(1,), padding=(3,), bias=False)
  )
  (abpBlock9): ResidualBlock(
    (bn1): BatchNorm1d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
    (dropout): Dropout(p=0.5, inplace=False)
    (conv1): Conv1d(4, 4, kernel_size=(7,), stride=(1,), padding=(3,), bias=False)
    (bn2): BatchNorm1d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv1d(4, 4, kernel_size=(7,), stride=(1,), padding=(3,), bias=False)
    (residualConv): Conv1d(4, 4, kernel_size=(7,), stride=(1,), padding=(3,), bias=False)
    (downsample): MaxPool1d(kernel_size=2, stride=2, padding=1, dilation=1, ceil_mode=False)
  )
  (abpBlock10): ResidualBlock(
    (bn1): BatchNorm1d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
    (dropout): Dropout(p=0.5, inplace=False)
    (conv1): Conv1d(4, 4, kernel_size=(7,), stride=(1,), padding=(3,), bias=False)
    (bn2): BatchNorm1d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv1d(4, 4, kernel_size=(7,), stride=(1,), padding=(3,), bias=False)
    (residualConv): Conv1d(4, 4, kernel_size=(7,), stride=(1,), padding=(3,), bias=False)
  )
  (abpBlock11): ResidualBlock(
    (bn1): BatchNorm1d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
    (dropout): Dropout(p=0.5, inplace=False)
    (conv1): Conv1d(4, 6, kernel_size=(7,), stride=(1,), padding=(3,), bias=False)
    (bn2): BatchNorm1d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv1d(6, 6, kernel_size=(7,), stride=(1,), padding=(3,), bias=False)
    (residualConv): Conv1d(4, 6, kernel_size=(7,), stride=(1,), padding=(3,), bias=False)
    (downsample): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (abpBlock12): ResidualBlock(
    (bn1): BatchNorm1d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
    (dropout): Dropout(p=0.5, inplace=False)
    (conv1): Conv1d(6, 6, kernel_size=(7,), stride=(1,), padding=(3,), bias=False)
    (bn2): BatchNorm1d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv1d(6, 6, kernel_size=(7,), stride=(1,), padding=(3,), bias=False)
    (residualConv): Conv1d(6, 6, kernel_size=(7,), stride=(1,), padding=(3,), bias=False)
  )
  (abpFc): Linear(in_features=2814, out_features=32, bias=True)
  (fullLinear1): Linear(in_features=32, out_features=16, bias=True)
  (fullLinear2): Linear(in_features=16, out_features=1, bias=True)
  (sigmoid): Sigmoid()
)

Plot the training and validation losses after each epoch:

In [72]:
# Create x-axis values for epochs
epochs = range(0, len(train_losses))

plt.figure(figsize=(16, 9))

# Plot the training and validation losses
plt.plot(epochs, train_losses, 'b', label='Training Loss')
plt.plot(epochs, val_losses, 'r', label='Validation Loss')

# Add a vertical bar at the best_epoch
plt.axvline(x=best_epoch, color='g', linestyle='--', label='Best Epoch')

# Shade everything to the right of the best_epoch a light red
plt.axvspan(best_epoch, max(epochs), facecolor='r', alpha=0.1)

# Add labels and title
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Training and Validation Losses')

# Add legend
plt.legend(loc='upper right')

# Show the plot
plt.show()
In [73]:
def eval_model(model, dataloader):
    model.eval()
    model = model.to(device)
    total_loss = 0
    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for abp, ecg, eeg, label in tqdm(dataloader):
            batch = len(abp)
    
            abp = torch.nan_to_num(abp.reshape(batch, 1, -1)).type(torch.FloatTensor).to(device)
            ecg = torch.nan_to_num(ecg.reshape(batch, 1, -1)).type(torch.FloatTensor).to(device)
            eeg = torch.nan_to_num(eeg.reshape(batch, 1, -1)).type(torch.FloatTensor).to(device)
            label = label.type(torch.float).reshape(batch, 1).to(device)
   
            pred = model(abp, eeg, ecg)
            loss = loss_func(pred, label)
            total_loss += loss.item()

            all_predictions.append(pred.detach().cpu().numpy())
            all_labels.append(label.detach().cpu().numpy())

    # Flatten the lists
    all_predictions = np.concatenate(all_predictions).flatten()
    all_labels = np.concatenate(all_labels).flatten()

    # Calculate AUROC and AUPRC
    # y_true, y_pred
    auroc = roc_auc_score(all_labels, all_predictions)
    precision, recall, _ = precision_recall_curve(all_labels, all_predictions)
    auprc = auc(recall, precision)

    # Determine the optimal threshold, which is argmin(abs(sensitivity - specificity)) per the paper
    thresholds = np.linspace(0, 1, 101) # 0 to 1 in 0.01 steps
    min_diff = float('inf')
    optimal_sensitivity = None
    optimal_specificity = None
    optimal_threshold = None

    for threshold in thresholds:
        all_predictions_binary = (all_predictions > threshold).astype(int)

        tn, fp, fn, tp = confusion_matrix(all_labels, all_predictions_binary).ravel()
        sensitivity = tp / (tp + fn)
        specificity = tn / (tn + fp)
        diff = abs(sensitivity - specificity)

        if diff < min_diff:
            min_diff = diff
            optimal_threshold = threshold
            optimal_sensitivity = sensitivity
            optimal_specificity = specificity

    avg_loss = total_loss / len(dataloader)
    return all_predictions, all_labels, avg_loss, auroc, auprc, optimal_sensitivity, optimal_specificity, optimal_threshold

# validation loop
valid_predictions, valid_labels, valid_loss, valid_auroc, valid_auprc, valid_sensitivity, valid_specificity, valid_threshold = eval_model(model, val_loader)

# test loop
test_predictions, test_labels, test_loss, test_auroc, test_auprc, test_sensitivity, test_specificity, test_threshold = eval_model(model, test_loader)

print(f'Best Epoch: {best_epoch}')
print()
print(f"Validation predictions: {valid_predictions}")
print(f"Validation labels: {valid_labels}")
print(f"Validation loss: {valid_loss}")
print(f"Validation AUROC: {valid_auroc}")
print(f"Validation AUPRC: {valid_auprc}")
print(f"Validation Sensitivity: {valid_sensitivity}")
print(f"Validation Specificity: {valid_specificity}")
print(f"Validation Threshold: {valid_threshold}")
print()
print(f"Test predictions: {test_predictions}")
print(f"Test labels: {test_labels}")
print(f"Test loss: {test_loss}")
print(f"Test AUROC: {test_auroc}")
print(f"Test AUPRC: {test_auprc}")
print(f"Test Sensitivity: {test_sensitivity}")
print(f"Test Specificity: {test_specificity}")
print(f"Test Threshold: {test_threshold}")
100%|███████████████████████████████████████████████████████████████████████████████████████| 69/69 [00:13<00:00,  5.02it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████| 204/204 [00:42<00:00,  4.82it/s]
Best Epoch: 4

Validation predictions: [0.11085445 0.09836218 0.05369908 ... 0.05172839 0.0518949  0.08113384]
Validation labels: [0. 0. 0. ... 1. 1. 0.]
Validation loss: 0.3205312129811964
Validation AUROC: 0.7119086032770398
Validation AUPRC: 0.2569351108811815
Validation Sensitivity: 0.6796714579055442
Validation Specificity: 0.6535714285714286
Validation Threshold: 0.11

Test predictions: [0.13859975 0.18436193 0.15511528 ... 0.09794634 0.18674852 0.08707681]
Test labels: [0. 0. 0. ... 0. 0. 0.]
Test loss: 0.34488859535286237
Test AUROC: 0.7127553382185996
Test AUPRC: 0.28786452822299974
Test Sensitivity: 0.6726596404215747
Test Specificity: 0.6407673440784863
Test Threshold: 0.11
In [74]:
PRINT_DETAILED = False

val_aurocs = []
val_auprcs = []

test_aurocs = []
test_auprcs = []

for all_mod in all_models:
    model.load_state_dict(torch.load(all_mod))
    model.train(False)
    
    # validation loop
    valid_predictions, valid_labels, valid_loss, valid_auroc, valid_auprc, valid_sensitivity, valid_specificity, valid_threshold = eval_model(model, val_loader)
    val_aurocs.append(valid_auroc)
    val_auprcs.append(valid_auprc)
    
    # test loop
    test_predictions, test_labels, test_loss, test_auroc, test_auprc, test_sensitivity, test_specificity, test_threshold = eval_model(model, test_loader)
    test_aurocs.append(test_auroc)
    test_auprcs.append(test_auprc)
    
    print(f'Model: {all_mod}')
    if PRINT_DETAILED:
        print(f"Validation predictions: {valid_predictions}")
        print(f"Validation labels: {valid_labels}")
    print(f"Validation loss: {valid_loss}")
    print(f"Validation AUROC: {valid_auroc}")
    print(f"Validation AUPRC: {valid_auprc}")
    print(f"Validation Sensitivity: {valid_sensitivity}")
    print(f"Validation Specificity: {valid_specificity}")
    print(f"Validation Threshold: {valid_threshold}")
    print()
    if PRINT_DETAILED:
        print(f"Test predictions: {test_predictions}")
        print(f"Test labels: {test_labels}")
    print(f"Test loss: {test_loss}")
    print(f"Test AUROC: {test_auroc}")
    print(f"Test AUPRC: {test_auprc}")
    print(f"Test Sensitivity: {test_sensitivity}")
    print(f"Test Specificity: {test_specificity}")
    print(f"Test Threshold: {test_threshold}")
100%|███████████████████████████████████████████████████████████████████████████████████████| 69/69 [00:14<00:00,  4.81it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████| 204/204 [00:41<00:00,  4.94it/s]
Model: ./vitaldb_cache/models/ABP_10_MINS_0000.model
Validation loss: 0.32120310871497443
Validation AUROC: 0.711436638310355
Validation AUPRC: 0.25366662377222426
Validation Sensitivity: 0.6509240246406571
Validation Specificity: 0.6948979591836735
Validation Threshold: 0.12

Test loss: 0.34487133122542324
Test AUROC: 0.7142997059629822
Test AUPRC: 0.2849495689025967
Test Sensitivity: 0.6416615003099814
Test Specificity: 0.679134548002803
Test Threshold: 0.12
100%|███████████████████████████████████████████████████████████████████████████████████████| 69/69 [00:14<00:00,  4.87it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████| 204/204 [00:41<00:00,  4.93it/s]
Model: ./vitaldb_cache/models/ABP_10_MINS_0001.model
Validation loss: 0.32897263743739197
Validation AUROC: 0.712564954113062
Validation AUPRC: 0.25567011233319714
Validation Sensitivity: 0.6386036960985626
Validation Specificity: 0.7056122448979592
Validation Threshold: 0.08

Test loss: 0.357669612137126
Test AUROC: 0.7143868407138739
Test AUPRC: 0.28819171124772985
Test Sensitivity: 0.6311221326720396
Test Specificity: 0.6896461107217939
Test Threshold: 0.08
100%|███████████████████████████████████████████████████████████████████████████████████████| 69/69 [00:14<00:00,  4.81it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████| 204/204 [00:41<00:00,  4.93it/s]
Model: ./vitaldb_cache/models/ABP_10_MINS_0002.model
Validation loss: 0.3318474776502969
Validation AUROC: 0.7119641285672379
Validation AUPRC: 0.25578035385727543
Validation Sensitivity: 0.6611909650924025
Validation Specificity: 0.6790816326530612
Validation Threshold: 0.07

Test loss: 0.361610484108621
Test AUROC: 0.7138609095857892
Test AUPRC: 0.2880802235945207
Test Sensitivity: 0.6584004959702418
Test Specificity: 0.6610896986685354
Test Threshold: 0.07
100%|███████████████████████████████████████████████████████████████████████████████████████| 69/69 [00:14<00:00,  4.87it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████| 204/204 [00:40<00:00,  4.98it/s]
Model: ./vitaldb_cache/models/ABP_10_MINS_0003.model
Validation loss: 0.3215533004722733
Validation AUROC: 0.7111852030339857
Validation AUPRC: 0.25521367753007196
Validation Sensitivity: 0.6550308008213552
Validation Specificity: 0.6793367346938776
Validation Threshold: 0.1

Test loss: 0.3478605173089925
Test AUROC: 0.7122484143593291
Test AUPRC: 0.2887601993140735
Test Sensitivity: 0.6540607563546187
Test Specificity: 0.6656447091800981
Test Threshold: 0.1
100%|███████████████████████████████████████████████████████████████████████████████████████| 69/69 [00:14<00:00,  4.89it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████| 204/204 [00:41<00:00,  4.96it/s]
Model: ./vitaldb_cache/models/ABP_10_MINS_0004.model
Validation loss: 0.3203725022250328
Validation AUROC: 0.7119086032770398
Validation AUPRC: 0.2569351108811815
Validation Sensitivity: 0.6796714579055442
Validation Specificity: 0.6535714285714286
Validation Threshold: 0.11

Test loss: 0.34488859535286237
Test AUROC: 0.7127553382185996
Test AUPRC: 0.28786452822299974
Test Sensitivity: 0.6726596404215747
Test Specificity: 0.6407673440784863
Test Threshold: 0.11
100%|███████████████████████████████████████████████████████████████████████████████████████| 69/69 [00:14<00:00,  4.90it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████| 204/204 [00:41<00:00,  4.97it/s]
Model: ./vitaldb_cache/models/ABP_10_MINS_0005.model
Validation loss: 0.3417626991868019
Validation AUROC: 0.7127179105728534
Validation AUPRC: 0.2569864666270052
Validation Sensitivity: 0.6427104722792608
Validation Specificity: 0.699234693877551
Validation Threshold: 0.06

Test loss: 0.374697232144136
Test AUROC: 0.7129801127489463
Test AUPRC: 0.2873913066126371
Test Sensitivity: 0.6422814631122132
Test Specificity: 0.6829011913104415
Test Threshold: 0.06
100%|███████████████████████████████████████████████████████████████████████████████████████| 69/69 [00:14<00:00,  4.85it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████| 204/204 [00:40<00:00,  4.98it/s]
Model: ./vitaldb_cache/models/ABP_10_MINS_0006.model
Validation loss: 0.3218131687330163
Validation AUROC: 0.7113360641998072
Validation AUPRC: 0.2547411589271957
Validation Sensitivity: 0.6837782340862423
Validation Specificity: 0.6558673469387755
Validation Threshold: 0.13

Test loss: 0.3448963412905441
Test AUROC: 0.7099352568979008
Test AUPRC: 0.28423448610948093
Test Sensitivity: 0.6683199008059516
Test Specificity: 0.6420812894183602
Test Threshold: 0.13
100%|███████████████████████████████████████████████████████████████████████████████████████| 69/69 [00:14<00:00,  4.89it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████| 204/204 [00:41<00:00,  4.97it/s]
Model: ./vitaldb_cache/models/ABP_10_MINS_0007.model
Validation loss: 0.3249230145112328
Validation AUROC: 0.711227632736873
Validation AUPRC: 0.252094660778317
Validation Sensitivity: 0.6652977412731006
Validation Specificity: 0.6716836734693877
Validation Threshold: 0.15

Test loss: 0.3464355053489699
Test AUROC: 0.7101784956322383
Test AUPRC: 0.283654632868036
Test Sensitivity: 0.6546807191568506
Test Specificity: 0.65583391730904
Test Threshold: 0.15
100%|███████████████████████████████████████████████████████████████████████████████████████| 69/69 [00:14<00:00,  4.84it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████| 204/204 [00:41<00:00,  4.97it/s]
Model: ./vitaldb_cache/models/ABP_10_MINS_0008.model
Validation loss: 0.32182940300823987
Validation AUROC: 0.7116820496165612
Validation AUPRC: 0.25315562868234687
Validation Sensitivity: 0.6796714579055442
Validation Specificity: 0.6502551020408164
Validation Threshold: 0.1

Test loss: 0.34827894772238593
Test AUROC: 0.709968791150737
Test AUPRC: 0.28370605665594545
Test Sensitivity: 0.6701797892126472
Test Specificity: 0.6368255080588647
Test Threshold: 0.1
100%|███████████████████████████████████████████████████████████████████████████████████████| 69/69 [00:14<00:00,  4.89it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████| 204/204 [00:41<00:00,  4.97it/s]
Model: ./vitaldb_cache/models/ABP_10_MINS_0009.model
Validation loss: 0.3211978922287623
Validation AUROC: 0.7117276222604032
Validation AUPRC: 0.2535507943283328
Validation Sensitivity: 0.6652977412731006
Validation Specificity: 0.6678571428571428
Validation Threshold: 0.11

Test loss: 0.34649836359655156
Test AUROC: 0.7097432834828789
Test AUPRC: 0.2805295946533841
Test Sensitivity: 0.6596404215747055
Test Specificity: 0.6520672740014015
Test Threshold: 0.11
100%|███████████████████████████████████████████████████████████████████████████████████████| 69/69 [00:14<00:00,  4.88it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████| 204/204 [00:40<00:00,  4.98it/s]
Model: ./vitaldb_cache/models/ABP_10_MINS_0010.model
Validation loss: 0.3207992671624474
Validation AUROC: 0.7102551544231656
Validation AUPRC: 0.2511244620519868
Validation Sensitivity: 0.6652977412731006
Validation Specificity: 0.6775510204081633
Validation Threshold: 0.12

Test loss: 0.34546954542690633
Test AUROC: 0.7081632906860906
Test AUPRC: 0.2784366124362746
Test Sensitivity: 0.6460012399256044
Test Specificity: 0.6610896986685354
Test Threshold: 0.12
100%|███████████████████████████████████████████████████████████████████████████████████████| 69/69 [00:14<00:00,  4.89it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████| 204/204 [00:41<00:00,  4.97it/s]
Model: ./vitaldb_cache/models/ABP_10_MINS_0011.model
Validation loss: 0.3225597726262134
Validation AUROC: 0.7093036290491557
Validation AUPRC: 0.2503529401185687
Validation Sensitivity: 0.6570841889117043
Validation Specificity: 0.698469387755102
Validation Threshold: 0.13

Test loss: 0.34596408501851794
Test AUROC: 0.7082369574293658
Test AUPRC: 0.2744195856506264
Test Sensitivity: 0.629262244265344
Test Specificity: 0.6758058864751226
Test Threshold: 0.13
100%|███████████████████████████████████████████████████████████████████████████████████████| 69/69 [00:14<00:00,  4.89it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████| 204/204 [00:41<00:00,  4.95it/s]
Model: ./vitaldb_cache/models/ABP_10_MINS_0012.model
Validation loss: 0.3283592664461205
Validation AUROC: 0.7074870091773877
Validation AUPRC: 0.24955493770504691
Validation Sensitivity: 0.6878850102669405
Validation Specificity: 0.6303571428571428
Validation Threshold: 0.08

Test loss: 0.35763226091569544
Test AUROC: 0.7046014099700619
Test AUPRC: 0.2722217538060982
Test Sensitivity: 0.6869187848729076
Test Specificity: 0.616240364400841
Test Threshold: 0.08
100%|███████████████████████████████████████████████████████████████████████████████████████| 69/69 [00:14<00:00,  4.89it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████| 204/204 [00:41<00:00,  4.97it/s]
Model: ./vitaldb_cache/models/ABP_10_MINS_0013.model
Validation loss: 0.32360575704470923
Validation AUROC: 0.7064833633658801
Validation AUPRC: 0.248023877674416
Validation Sensitivity: 0.6694045174537988
Validation Specificity: 0.6548469387755103
Validation Threshold: 0.1

Test loss: 0.34992039638261
Test AUROC: 0.706516718141971
Test AUPRC: 0.2694860144411936
Test Sensitivity: 0.6559206447613143
Test Specificity: 0.6391030133146461
Test Threshold: 0.1
100%|███████████████████████████████████████████████████████████████████████████████████████| 69/69 [00:14<00:00,  4.89it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████| 204/204 [00:41<00:00,  4.96it/s]
Model: ./vitaldb_cache/models/ABP_10_MINS_0014.model
Validation loss: 0.3227913468212321
Validation AUROC: 0.7061182583916522
Validation AUPRC: 0.24650987637429228
Validation Sensitivity: 0.6673511293634496
Validation Specificity: 0.6686224489795919
Validation Threshold: 0.11

Test loss: 0.34880487606221555
Test AUROC: 0.7032315018001513
Test AUPRC: 0.26704767771647275
Test Sensitivity: 0.6404215747055176
Test Specificity: 0.6525052557813595
Test Threshold: 0.11
100%|███████████████████████████████████████████████████████████████████████████████████████| 69/69 [00:14<00:00,  4.90it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████| 204/204 [00:41<00:00,  4.96it/s]
Model: ./vitaldb_cache/models/ABP_10_MINS_0015.model
Validation loss: 0.32583489063857257
Validation AUROC: 0.7046211708502703
Validation AUPRC: 0.24443694759039475
Validation Sensitivity: 0.6899383983572895
Validation Specificity: 0.6329081632653061
Validation Threshold: 0.09

Test loss: 0.3535620750238498
Test AUROC: 0.7027570260640704
Test AUPRC: 0.26430333373648507
Test Sensitivity: 0.6726596404215747
Test Specificity: 0.62114576033637
Test Threshold: 0.09
100%|███████████████████████████████████████████████████████████████████████████████████████| 69/69 [00:14<00:00,  4.90it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████| 204/204 [00:41<00:00,  4.95it/s]
Model: ./vitaldb_cache/models/ABP_10_MINS_0016.model
Validation loss: 0.32319997056670813
Validation AUROC: 0.7031348216904832
Validation AUPRC: 0.24161776067140825
Validation Sensitivity: 0.675564681724846
Validation Specificity: 0.6400510204081633
Validation Threshold: 0.11

Test loss: 0.34859373267082605
Test AUROC: 0.700473166949857
Test AUPRC: 0.2622604317828593
Test Sensitivity: 0.6689398636081835
Test Specificity: 0.626927119831815
Test Threshold: 0.11
100%|███████████████████████████████████████████████████████████████████████████████████████| 69/69 [00:14<00:00,  4.90it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████| 204/204 [00:41<00:00,  4.97it/s]
Model: ./vitaldb_cache/models/ABP_10_MINS_0017.model
Validation loss: 0.32341467491958453
Validation AUROC: 0.7008171646482001
Validation AUPRC: 0.23835035504281932
Validation Sensitivity: 0.675564681724846
Validation Specificity: 0.6364795918367347
Validation Threshold: 0.11

Test loss: 0.3483743406832218
Test AUROC: 0.6993296896580038
Test AUPRC: 0.25807789364078204
Test Sensitivity: 0.666460012399256
Test Specificity: 0.623510861948143
Test Threshold: 0.11
100%|███████████████████████████████████████████████████████████████████████████████████████| 69/69 [00:14<00:00,  4.90it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████| 204/204 [00:41<00:00,  4.96it/s]
Model: ./vitaldb_cache/models/ABP_10_MINS_0018.model
Validation loss: 0.3222654038581295
Validation AUROC: 0.7046531240833089
Validation AUPRC: 0.24356908596809165
Validation Sensitivity: 0.6899383983572895
Validation Specificity: 0.6311224489795918
Validation Threshold: 0.11

Test loss: 0.3467049452604032
Test AUROC: 0.7027705755314106
Test AUPRC: 0.26205454326659183
Test Sensitivity: 0.6707997520148791
Test Specificity: 0.6201822004204625
Test Threshold: 0.11
100%|███████████████████████████████████████████████████████████████████████████████████████| 69/69 [00:14<00:00,  4.89it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████| 204/204 [00:40<00:00,  4.98it/s]
Model: ./vitaldb_cache/models/ABP_10_MINS_0019.model
Validation loss: 0.334323824747749
Validation AUROC: 0.6886089343334869
Validation AUPRC: 0.2258062720327461
Validation Sensitivity: 0.6344969199178645
Validation Specificity: 0.6632653061224489
Validation Threshold: 0.16

Test loss: 0.356246235543022
Test AUROC: 0.6835170811264989
Test AUPRC: 0.2401601032518599
Test Sensitivity: 0.6137631742095474
Test Specificity: 0.651541695865452
Test Threshold: 0.16
In [75]:
# Create x-axis values for epochs
epochs = range(0, len(val_aurocs))

# Find model with highest AUROC
np_test_aurocs = np.array(test_aurocs)
test_auroc_idx = np.argmax(np_test_aurocs)

print(f'Epoch with best Validation Loss:  {best_epoch:3}, {val_losses[best_epoch]:.4}')
print(f'Epoch with best model Test AUROC: {test_auroc_idx:3}, {np.max(np_test_aurocs):.4}')
print(f'Best Model on Validation Loss:    {all_models[test_auroc_idx]}')
print(f'Best Model on Test AUROC:         {all_models[best_epoch]}')

plt.figure(figsize=(16, 9))

# Plot the training and validation losses
plt.plot(epochs, val_aurocs, 'C0', label='AUROC - Validation')
plt.plot(epochs, test_aurocs, 'C1', label='AUROC - Test')

plt.plot(epochs, val_auprcs, 'C2', label='AUPRC - Validation')
plt.plot(epochs, test_auprcs, 'C3', label='AUPRC - Test')

# Add a vertical bar at the best_epoch
plt.axvline(x=best_epoch, color='g', linestyle='--', label='Best Epoch - Validation Loss')
plt.axvline(x=test_auroc_idx, color='maroon', linestyle='--', label='Best Epoch - Test AUROC')

# Shade everything to the right of the best_model a light red
plt.axvspan(test_auroc_idx, max(epochs), facecolor='r', alpha=0.1)

# Add labels and title
plt.xlabel('Epochs')
plt.ylabel('AUROC / AUPRC')
plt.title('Validation and Test AUROC by Model Iteration Across Training')

# Add legend
plt.legend(loc='right')

# Show the plot
plt.show()
Epoch with best Validation Loss:    4, 0.3203
Epoch with best model Test AUROC:   1, 0.7144
Best Model on Validation Loss:    ./vitaldb_cache/models/ABP_10_MINS_0001.model
Best Model on Test AUROC:         ./vitaldb_cache/models/ABP_10_MINS_0004.model

AUROC / AUPRC - Model with Best Validation Loss¶

In [76]:
best_model_val_loss = all_models[best_epoch]

print(f'Best Model Based on Validation Loss: {best_model_val_loss}')
model.load_state_dict(torch.load(best_model_val_loss))
model.train(False)

(best_model_val_test_predictions, best_model_val_test_labels, test_loss, 
 test_auroc, best_model_val_test_auprc, test_sensitivity, test_specificity, best_model_val_test_threshold) \
    = eval_model(model, test_loader)
Best Model Based on Validation Loss: ./vitaldb_cache/models/ABP_10_MINS_0004.model
100%|█████████████████████████████████████████████████████████████████████████████████████| 204/204 [00:41<00:00,  4.97it/s]
In [77]:
# y_test, y_pred
display = RocCurveDisplay.from_predictions(
    best_model_val_test_labels,
    best_model_val_test_predictions,
    plot_chance_level=True
)
plt.show()
In [78]:
roc_auc_score(best_model_val_test_labels, best_model_val_test_predictions)
Out[78]:
0.7127553382185996
In [79]:
best_model_val_test_predictions_binary = \
    (best_model_val_test_predictions > best_model_val_test_threshold).astype(int)

# y_test, y_pred
display = PrecisionRecallDisplay.from_predictions(
    best_model_val_test_labels, 
    best_model_val_test_predictions_binary,
    plot_chance_level=True
)
plt.show()
In [80]:
# y_test, y_pred
display = PrecisionRecallDisplay.from_predictions(
    best_model_val_test_labels, 
    best_model_val_test_predictions,
    plot_chance_level=True
)
plt.show()
In [81]:
best_model_val_test_auprc
Out[81]:
0.28786452822299974

AUROC / AUPRC - Model with Best AUROC¶

In [82]:
best_model_auroc = all_models[test_auroc_idx]

print(f'Best Model Based on Model AUROC: {best_model_auroc}')
model.load_state_dict(torch.load(best_model_auroc))
model.train(False)

(best_model_auroc_test_predictions, best_model_auroc_test_labels, test_loss, 
 test_auroc, best_model_auroc_test_auprc, test_sensitivity, test_specificity, best_model_auroc_test_threshold) \
    = eval_model(model, test_loader)
Best Model Based on Model AUROC: ./vitaldb_cache/models/ABP_10_MINS_0001.model
100%|█████████████████████████████████████████████████████████████████████████████████████| 204/204 [00:41<00:00,  4.95it/s]
In [83]:
# y_test, y_pred
display = RocCurveDisplay.from_predictions(
    best_model_auroc_test_labels,
    best_model_auroc_test_predictions,
    plot_chance_level=True
)
plt.show()
In [84]:
roc_auc_score(best_model_auroc_test_labels, best_model_auroc_test_predictions)
Out[84]:
0.7143868407138739
In [85]:
best_model_auroc_test_predictions_binary = \
    (best_model_auroc_test_predictions > best_model_auroc_test_threshold).astype(int)

# y_test, y_pred
display = PrecisionRecallDisplay.from_predictions(
    best_model_auroc_test_labels, 
    best_model_auroc_test_predictions_binary,
    plot_chance_level=True
)
plt.show()
In [86]:
# y_test, y_pred
display = PrecisionRecallDisplay.from_predictions(
    best_model_auroc_test_labels, 
    best_model_auroc_test_predictions,
    plot_chance_level=True
)
plt.show()
In [87]:
best_model_auroc_test_auprc
Out[87]:
0.28819171124772985

Results (Planned results for Draft submission)¶

When we complete our experiments, we will build comparison tables that compare a set of measures for each experiment performed. The full set of experiments and measures are listed below.

Results from Final Rubrik¶

  • Table of results (no need to include additional experiments, but main reproducibility result should be included)
  • All claims should be supported by experiment results
  • Discuss with respect to the hypothesis and results from the original paper
  • Experiments beyond the original paper
    • Each experiment should include results and a discussion
  • Ablation Study.

Experiments¶

  • ABP only
  • ECG only
  • EEG only
  • ABP + ECG
  • ABP + EEG
  • ECG + EEG
  • ABP + ECG + EEG

Note: each experiment will be repeated with the following time-to-IOH-event durations:

  • 3 minutes
  • 5 minutes
  • 10 minutes
  • 15 minutes

Note: the above list of experiments will be performed if there is sufficient time and gpu capability to complete that before the submission deadline. Should we experience any constraints on this front, we will reduce our experimental coverage to the following 4 core experiments that are necessary to measure the hypotheses included at the head of this report:

  • ABP only @ 3 minutes
  • ABP + ECG @ 3 minutes
  • ABP + EEG @ 3 minutes
  • ABP + ECG + EEG @ 3 minutes

For additional details please review the "Planned Actions" in the Discussion section of this report.

Measures¶

  • AUROC
  • AUPRC
  • Sensitivity
  • Specificity
  • Threshold
  • Loss Shrinkage

[ TODO for final report - collect data for all measures listed above. ]

[ TODO for final report - generate ROC and PRC plots for each experiment ]

We are collecting a broad set of measures across each experiment in order to perform a comprehensive comparison of all measures listed across all comparable experiments executed in the original paper. However, our key experimental results will be focused on a subset of these results that address the main experiments defined at the beginning of this notebook.

The key experimental result measures will be as follows:

  • For 3 minutes ahead of the predicted IOH event:
    • compare AUROC and AUPRC for ABP only vs ABP+ECG
    • compare AUROC and AUPRC for ABP only vs ABP+EEG
    • compare AUROC and AUPRC for ABP only vs ABP+ECG+EEG

Model comparison¶

The following table is Table 3 from the original paper which presents the measured values for each signal combination across each of the four temporal predictive categories:

Area under the Receiver-operating Characteristic Curve, Area under the Precision-Recall Curve, Sensitivity, and Specificity of the model in predicting intraoperative hypotension

We have not yet completed the execution of the experiments necessary to determine our reproduced model performance in order determine whether our results are accurately representing those of the original paper. These details are expected to be included in the final report.

As of the draft submission, the reported evaluation measures of our model are too good to be true (all measures are 1.0). We suspect that there is data leakage in the dataset splitting process and will address this in time for the final report.

Discussion¶

Discussion (10) FROM FINAL RUBRIK¶

  • Implications of the experimental results, whether the original paper was reproducible, and if it wasn’t, what factors made it irreproducible
  • “What was easy”
  • “What was difficult”
  • Recommendations to the original authors or others who work in this area for improving reproducibility
  • (specific to our group) "I have communicated with Maciej during OH. The draft looks good and I would expect some explanations/analysis on the final report on why you get 1.0 as AUROC."
    • discuss our bug where we were believing we were sampling dozens of different patient samples but were just training the model on the same segments extracted from the same patient sample over and over. so we were massively overfitting our training data for one patient's data, then unwittingly using the same patient data for validation and testing, thus getting perfect classification during inference.

Feasibility of reproduction¶

Our assessment is that this paper will be reproducible. The outstanding risk is that each experiment can take up to 7 hours to run on hardware within the team (i.e., 7h to run ~70 epochs on a desktop with AMD Ryzen 7 3800X 8-core CPU w/ RTX 2070 SUPER GPU and 32GB RAM). There are a total of 28 experiments (7 different combinations of signal inputs, 4 different time horizons for each combination). Should our team find it not possible to complete the necessary experiments across all of the experiments represented in Table 3 of our selected paper, we will reduce the number of experiments to focus solely on the ones directly related to our hypotheses described in the beginning of this notebook (i.e., reduce the number of combinations of interest to 4: ABP alone, ABP+EEG, ABP+ECG, ABP+ECG+EEG). This will result in a new total of 16 experiments to run.

Planned ablations¶

Our proposal included a collection of potential ablations to be investigated:

  • Remove ResNet skip connection
  • Reduce # of residual blocks from 12 to 6
  • Reduce # of residual blocks from 12 to 1
  • Eliminate dropout from residual block
  • Max pooling configuration
    • smaller size/stride
    • eliminate max pooling

Given the amount of time required to conduct each experiment, our team intends to choose only a small number of ablations from this set. Further, we only intend to perform ablation analysis against the best performing signal combination and time horizon from the reproduction experiments. In order words, we intend to perform ablation analysis against the following training combinations, and only against the models trained with data measured 3 minutes prior to an IOH event:

  • ABP alone
  • ABP + ECG
  • ABP + EEG
  • ABP + ECG + EEG

Time and GPU resource permitting, we will complete a broader range of experiments. For additional details, please see the section below titled "Plans for next phase".

Nature of reproduced results¶

Our team intends to address the manner in which the experimental results align with the published results in the paper in the final submission of this report. The amount of time required to complete model training and result analysis during the preparation of the Draft notebook was not sufficient to complete a large number of experiments.

What was easy? What was difficult?¶

The difficult aspect of the preparation of this draft involved the data preprocessing.

  • First, the source data is unlabelled, so our team was responsible for implementing analysis methods for identifying positive (IOH event occurred) and negative (IOH event did not occur) by running a lookahead analysis of our input training set.
  • Second, the volume of raw data is in excess of 90GB. A non-trivial amount of compute was required to minify the input data to only include the data tracks of interest to our experiments (i.e., ABP, ECG, and EEG tracks).
  • Third, our team found it difficult to trace back to the definition of the jSQI signal quality index referenced in the paper. Multiple references through multiple papers needed to be traversed to understand which variant of the quality index
    • The only available source code related to the signal quality index as referenced by our paper in [5]. Source code was not directly linked from the paper, but the GitHub repository for the corresponding author for reference [5] did result in the identification of MATLAB source code for the signal quality index as described in the referenced paper. That code is available here: https://github.com/cliffordlab/PhysioNet-Cardiovascular-Signal-Toolbox/tree/master/Tools/BP_Tools
    • Our team had insufficient time to port this signal quality index to Python for use in our investigation, or to setup a MATLAB environment in which to assess our source data using the above MATLAB functions, but we expect to complete this as part of our final report.

Suggestions to paper author¶

The most notable suggestion would be to correct the hyperparameters published in Supplemental Table 1. Specifically, the output size for residual blocks 11 and 12 for the ECG and ABP data sets was 496x6. This is a typo, and should read 469x6. This typo became apparent when operating the size down operation within Residual Block 11 and recognizing the tensor dimensions were misaligned.

Additionally, more explicit references to the signal quality index assessment tools should be added. Our team could not find a reference to the MATLAB source code as described in reference [3], and had to manually discover the GitHub profile for the lab of the corresponding author of reference [3] in order to find MATLAB source that corresponded to the metrics described therein.

Plans for next phase¶

Our team plans to accomplish the following goals in service of preparing the Final Report:

  • Implement the jSQI filter to remove any training data with aberrent signal quality per the threshold defined in our original paper.
  • Execute the following experiments:
    • Measure predictive quality of the model trained solely with ABP data at 3 minutes prior to IOH events.
    • Measure predictive quality of the model trained with ABP+ECG data at 3 minutes prior to IOH events.
    • Measure predictive quality of the model trained with ABP+EEG data at 3 minutes prior to IOH events.
    • Measure predictive quality of the model trained with ABP+ECG+EEG data at 3 minutes prior to IOH events.
  • Gather our measures for these experiments and perform a comparison against the published results from our selected paper and determine whether or not we are succesfully reproducing the results outlined in the paper.
  • Ablation analysis:
    • Execute the following ablation experiments:
      • Repeat the four experiments described above while reducing the numnber of residual blocks in the model from 12 to 6.
  • Time- and/or GPU-resource permitting, we will complete the remaining 24 experiments as described in the paper:
    • Measure predictive quality of the model trained solely with ABP data at 5, 10, and 15 minutes prior to IOH events.
    • Measure predictive quality of the model trained with ABP+ECG data at 5, 10, and 15 minutes prior to IOH events.
    • Measure predictive quality of the model trained with ABP+EEG data at 5, 10, and 15 minutes prior to IOH events.
    • Measure predictive quality of the model trained with ABP+ECG+EEG data at 5, 10, and 15 minutes prior to IOH events.
    • Measure predictive quality of the model trained solely with ECG data at 3, 5, 10, and 15 minutes prior to IOH events.
    • Measure predictive quality of the model trained solely with EEG data at 3, 5, 10, and 15 minutes prior to IOH events.
    • Measure predictive quality of the model trained with ECG+EEG data at 3, 5, 10, and 15 minutes prior to IOH events.
    • Additional ablation experiments:
      • For the four core experiments (ABP, ABP+ECG, ABP+EEG, ABP+ECG+EEG each trained on event data occurring 3 minutes prior to IOH events), perform the following ablations:
        • Repeat experiment while eliminating dropout from every residual block
        • Repeat experiment while removing the skip connection from every residual block
        • Repeat the four experiments described above while reducing the numnber of residual blocks in the model from 12 to 1.

References¶

  1. Jo Y-Y, Jang J-H, Kwon J-m, Lee H-C, Jung C-W, Byun S, et al. “Predicting intraoperative hypotension using deep learning with waveforms of arterial blood pressure, electroencephalogram, and electrocardiogram: Retrospective study.” PLoS ONE, (2022) 17(8): e0272055 https://doi.org/10.1371/journal.pone.0272055
  2. Hatib, Feras, Zhongping J, Buddi S, Lee C, Settels J, Sibert K, Rhinehart J, Cannesson M “Machine-learning Algorithm to Predict Hypotension Based on High-fidelity Arterial Pressure Waveform Analysis” Anesthesiology (2018) 129:4 https://doi.org/10.1097/ALN.0000000000002300
  3. Bao, X., Kumar, S.S., Shah, N.J. et al. "AcumenTM hypotension prediction index guidance for prevention and treatment of hypotension in noncardiac surgery: a prospective, single-arm, multicenter trial." Perioperative Medicine (2024) 13:13 https://doi.org/10.1186/s13741-024-00369-9
  4. Lee, HC., Park, Y., Yoon, S.B. et al. VitalDB, a high-fidelity multi-parameter vital signs database in surgical patients. Sci Data 9, 279 (2022). https://doi.org/10.1038/s41597-022-01411-5
  5. Li Q., Mark R.G. & Clifford G.D. "Artificial arterial blood pressure artifact models and an evaluation of a robust blood pressure and heart rate estimator." BioMed Eng OnLine. (2009) 8:13. pmid:19586547 https://doi.org/10.1186/1475-925X-8-13
  6. Park H-J, "VitalDB Python Example Notebooks" GitHub Repository https://github.com/vitaldb/examples/blob/master/hypotension_art.ipynb

Public GitHub Repo (5)¶

  • Publish your code in a public repository on GitHub and attach the URL in the notebook.
  • Make sure your code is documented properly.
    • A README.md file describing the exact steps to run your code is required.
    • Check “ML Code Completeness Checklist” (https://github.com/paperswithcode/releasing-research-code)
    • Check “Best Practices for Reproducibility” (https://www.cs.mcgill.ca/~ksinha4/practices_for_reproducibility/)

Video Presentation (Requirements from Rubrik)¶

Walkthrough of the notebook, no need to make slides. We expect a well-timed, well-presented presentation. You should clearly explain what the original paper is about (what the general problem is, what the specific approach taken was, and what the results claimed were) and what you encountered when you attempted to reproduce the results. You should use the time given to you and not too much (or too little).

  • <= 4 mins
  • Explain the general problem clearly
  • Explain the specific approach taken in the paper clearly
  • Explain reproduction attempts clearly